"""Train a sentence-encoder + MLP head that maps sign text to latitude/longitude."""

import argparse
import json
import logging
import os
import random

import numpy as np
import pandas as pd
import torch
from sentence_transformers import LoggingHandler, SentenceTransformer
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset

logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Mean Earth radius in kilometres, used by the haversine distance.
EARTH_RADIUS_KM = 6371.0088


class SignCoordinateDataset(Dataset):
    """Pairs raw texts with their coordinate targets.

    Items are ``(text, coordinates)`` tuples; tokenization happens later in
    the DataLoader's ``collate_fn``.
    """

    def __init__(self, texts, coordinates):
        self.texts = list(texts)
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return self.texts[index], self.coordinates[index]


class EmbeddingCoordinateDataset(Dataset):
    """Pairs precomputed sentence embeddings with their coordinate targets.

    Used when the encoder is fully frozen so embeddings can be cached once.
    """

    def __init__(self, embeddings, coordinates):
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        return self.embeddings[index], self.coordinates[index]


class CoordinateRegressor(nn.Module):
    """Two-hidden-layer MLP mapping an embedding to 2 outputs (lat, lon).

    Layer widths are ``embedding_dim -> hidden_dim -> hidden_dim // 2 -> 2``
    with ReLU and dropout after each hidden layer.
    """

    def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 2),
        )

    def forward(self, embeddings):
        return self.layers(embeddings)


def parse_args():
    """Parse command-line options for training."""
    parser = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
    parser.add_argument("--data-file", default="training.csv")
    parser.add_argument("--output-path", default="output")
    parser.add_argument(
        "--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    parser.add_argument("--device", default=None)
    parser.add_argument("--seed", type=int, default=1992)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--head-learning-rate", type=float, default=1e-3)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--test-size", type=float, default=0.2)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument(
        "--freeze-encoder",
        action="store_true",
        help="Train only the coordinate head; keep the sentence encoder fixed.",
    )
    parser.add_argument(
        "--freeze-transformer-layers",
        type=int,
        default=0,
        help="Freeze the first N transformer layers in the sentence encoder.",
    )
    parser.add_argument(
        "--freeze-attention",
        action="store_true",
        help=(
            "Freeze self-attention parameters while leaving other encoder "
            "params trainable."
        ),
    )
    return parser.parse_args()


def get_device(requested_device):
    """Return the requested torch device, else CUDA if available, else CPU."""
    if requested_device:
        return torch.device(requested_device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def set_seed(seed):
    """Seed Python, NumPy and torch (including all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def normalize_coordinates(coordinates):
    """Standardize each column of ``coordinates`` to zero mean / unit std.

    Returns:
        ``(normalized, mean, std)`` where ``mean`` and ``std`` are per-column.
        Zero-variance columns get ``std = 1`` to avoid division by zero.
    """
    mean = coordinates.mean(axis=0)
    std = coordinates.std(axis=0)
    std[std == 0] = 1.0
    return (coordinates - mean) / std, mean, std


def collate_fn(model, device):
    """Build a collate function that tokenizes texts with ``model``.

    The returned callable maps a list of ``(text, label)`` items to
    ``(features, labels)`` with every tensor already moved to ``device``.
    """

    def collate(batch):
        texts, labels = zip(*batch)
        features = model.tokenize(list(texts))
        features = {key: value.to(device) for key, value in features.items()}
        return features, torch.stack(labels).to(device)

    return collate


@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
    """Embed ``texts`` in batches and return the stacked embeddings.

    Rows of the returned NumPy array follow the order of ``texts``.
    Leaves ``encoder`` in eval mode.
    """
    encoder.eval()
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        features = {key: value.to(device) for key, value in features.items()}
        batch_embeddings = encoder(features)["sentence_embedding"]
        embeddings.append(batch_embeddings.cpu().numpy())
    return np.vstack(embeddings)


def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
    """Train only ``head`` for one epoch on cached embeddings.

    Returns:
        The sample-weighted mean training loss for the epoch.
    """
    head.train()
    total_loss = 0.0
    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
    return total_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate ``head`` on cached embeddings.

    Returns:
        ``(mean_loss, mean_error_km)``: the loss is on normalized targets;
        the error is the haversine distance after de-normalizing with
        ``coord_mean`` / ``coord_std``.
    """
    head.eval()
    total_loss = 0.0
    errors_km = []
    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        total_loss += loss.item() * labels.size(0)
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))
    return total_loss / len(dataloader.dataset), float(np.mean(errors_km))


def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
    """Train encoder and head end-to-end for one epoch on tokenized batches.

    When ``encoder_trainable`` is false, the encoder runs in eval mode under
    ``no_grad`` and only the head receives gradients.

    Returns:
        The sample-weighted mean training loss for the epoch.
    """
    if encoder_trainable:
        encoder.train()
    else:
        encoder.eval()
    head.train()
    total_loss = 0.0
    for features, labels in dataloader:
        optimizer.zero_grad()
        if encoder_trainable:
            embeddings = encoder(features)["sentence_embedding"]
        else:
            with torch.no_grad():
                embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
    return total_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std):
    """Evaluate encoder + head on tokenized batches.

    Returns:
        ``(mean_loss, mean_error_km)`` computed as in ``evaluate_head``.
    """
    encoder.eval()
    head.eval()
    total_loss = 0.0
    errors_km = []
    for features, labels in dataloader:
        embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        total_loss += loss.item() * labels.size(0)
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))
    return total_loss / len(dataloader.dataset), float(np.mean(errors_km))


def haversine_km(pred_coords, true_coords):
    """Return great-circle distances (km) between paired coordinate rows.

    Both arrays are indexed ``[:, 0]`` = latitude and ``[:, 1]`` =
    longitude, in degrees.
    """
    lat1 = np.radians(pred_coords[:, 0])
    lon1 = np.radians(pred_coords[:, 1])
    lat2 = np.radians(true_coords[:, 0])
    lon2 = np.radians(true_coords[:, 1])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Floating-point rounding (notably with float32 model outputs near
    # antipodal points) can push ``a`` marginally outside [0, 1], which would
    # make sqrt/arcsin return NaN and poison the mean error.
    a = np.clip(a, 0.0, 1.0)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def _write_coordinate_config(output_path, coord_mean, coord_std, args):
    """Write head hyperparameters and target normalization stats as JSON."""
    metadata = {
        "coord_mean": coord_mean.tolist(),
        "coord_std": coord_std.tolist(),
        "hidden_dim": args.hidden_dim,
        "dropout": args.dropout,
        "model_name": args.model_name,
    }
    config_path = os.path.join(output_path, "coordinate_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


def save_model(output_path, encoder, head, coord_mean, coord_std, args):
    """Save the encoder, head weights and coordinate config to ``output_path``."""
    os.makedirs(output_path, exist_ok=True)
    encoder.save(output_path)
    torch.save(head.state_dict(), os.path.join(output_path, "coordinate_head.pt"))
    _write_coordinate_config(output_path, coord_mean, coord_std, args)


def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
    """Snapshot the untrained encoder and head alongside the coordinate config.

    The encoder goes to ``<output_path>/initial_encoder`` and the head to
    ``initial_coordinate_head.pt`` so they do not collide with ``save_model``.
    """
    os.makedirs(output_path, exist_ok=True)
    encoder.save(os.path.join(output_path, "initial_encoder"))
    torch.save(
        head.state_dict(),
        os.path.join(output_path, "initial_coordinate_head.pt"),
    )
    _write_coordinate_config(output_path, coord_mean, coord_std, args)


def freeze_encoder_parts(encoder, args):
    """Disable gradients on encoder parameters according to CLI flags.

    ``--freeze-encoder`` freezes everything. Otherwise the first
    ``--freeze-transformer-layers`` layers and/or all parameters whose name
    contains an ``attention`` component are frozen.

    NOTE(review): assumes ``encoder[0].auto_model`` exposes a BERT-style
    ``encoder.layer`` list — confirm for non-BERT backbones.
    """
    if args.freeze_encoder:
        for parameter in encoder.parameters():
            parameter.requires_grad = False
        return

    transformer = encoder[0].auto_model
    if args.freeze_transformer_layers > 0:
        layers = transformer.encoder.layer[: args.freeze_transformer_layers]
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False

    if args.freeze_attention:
        for name, parameter in transformer.named_parameters():
            if ".attention." in name or name.startswith("attention."):
                parameter.requires_grad = False


def count_trainable_parameters(module):
    """Return ``(trainable, total)`` parameter element counts for ``module``."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total


def make_optimizer(encoder, head, args):
    """Build AdamW with separate learning rates for encoder and head.

    Encoder parameters (only those still requiring grad) use
    ``--learning-rate``; the head uses ``--head-learning-rate``. Both groups
    share ``--weight-decay``.
    """
    parameter_groups = []
    encoder_parameters = [p for p in encoder.parameters() if p.requires_grad]
    if encoder_parameters:
        group = {"params": encoder_parameters, "lr": args.learning_rate}
        parameter_groups.append(group)
    parameter_groups.append(
        {"params": head.parameters(), "lr": args.head_learning_rate}
    )
    return torch.optim.AdamW(parameter_groups, weight_decay=args.weight_decay)


def _split_indices(data, test_size, seed):
    """Return ``(train_indices, val_indices)`` positional row indices.

    If an ``intersection`` column exists, rows sharing an intersection are
    kept on the same side of the split; otherwise rows are split randomly.
    """
    indices = np.arange(len(data))
    if "intersection" in data.columns:
        splitter = GroupShuffleSplit(
            n_splits=1, test_size=test_size, random_state=seed
        )
        return next(splitter.split(indices, groups=data["intersection"]))
    return train_test_split(indices, test_size=test_size, random_state=seed)


def main():
    """Load data, split it, train the model and save the best checkpoint."""
    args = parse_args()
    set_seed(args.seed)
    device = get_device(args.device)
    print(f"Using device: {device}")

    data = pd.read_csv(args.data_file)
    data = data.dropna(subset=["text", "latitude", "longitude"])
    texts = data["text"].astype(str).tolist()
    coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)

    train_indices, val_indices = _split_indices(data, args.test_size, args.seed)

    # Fit normalization statistics on the training split only so validation
    # targets do not leak into preprocessing; apply them to every row.
    _, coord_mean, coord_std = normalize_coordinates(coordinates[train_indices])
    normalized_coordinates = (coordinates - coord_mean) / coord_std

    train_dataset = SignCoordinateDataset(
        [texts[i] for i in train_indices], normalized_coordinates[train_indices]
    )
    val_dataset = SignCoordinateDataset(
        [texts[i] for i in val_indices], normalized_coordinates[val_indices]
    )

    encoder = SentenceTransformer(args.model_name, device=str(device))
    embedding_dim = encoder.get_sentence_embedding_dimension()
    head = CoordinateRegressor(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)
    save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
    freeze_encoder_parts(encoder, args)

    encoder_trainable, encoder_total = count_trainable_parameters(encoder)
    head_trainable, head_total = count_trainable_parameters(head)
    print(
        f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
        f"head params: {head_trainable:,}/{head_total:,}"
    )

    if encoder_trainable == 0:
        # Frozen encoder: embed every text once and train the head on cached
        # embeddings instead of re-running the encoder each epoch.
        print("Caching frozen encoder embeddings...")
        all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
        train_dataset = EmbeddingCoordinateDataset(
            all_embeddings[train_indices], normalized_coordinates[train_indices]
        )
        val_dataset = EmbeddingCoordinateDataset(
            all_embeddings[val_indices], normalized_coordinates[val_indices]
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=collate_fn(encoder, device),
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn(encoder, device),
        )

    optimizer = make_optimizer(encoder, head, args)
    loss_fn = nn.MSELoss()
    best_val_loss = float("inf")

    print(
        f"Training on {len(train_dataset):,} rows; "
        f"validating on {len(val_dataset):,} rows"
    )
    for epoch in range(1, args.epochs + 1):
        if encoder_trainable == 0:
            train_loss = train_head_epoch(
                head, train_loader, optimizer, loss_fn, device
            )
            val_loss, val_error_km = evaluate_head(
                head, val_loader, loss_fn, coord_mean, coord_std, device
            )
        else:
            train_loss = train_epoch(
                encoder,
                head,
                train_loader,
                optimizer,
                loss_fn,
                True,
            )
            val_loss, val_error_km = evaluate(
                encoder, head, val_loader, loss_fn, coord_mean, coord_std
            )
        print(
            f"epoch={epoch} train_loss={train_loss:.6f} "
            f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
        )
        # Checkpoint whenever validation loss (on normalized targets) improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(args.output_path, encoder, head, coord_mean, coord_std, args)
            print(f"Saved best model to {args.output_path}")


if __name__ == "__main__":
    main()