"""Train a model that predicts a sign's latitude/longitude from its text.

A SentenceTransformer encoder is fine-tuned jointly with a small MLP head
(``CoordinateRegressor``) that regresses z-score-normalized coordinates.
The checkpoint with the lowest validation loss (MSE in normalized space) is
written to ``--output-path`` together with the normalization statistics
needed to map predictions back to degrees.
"""

import argparse
import json
import logging
import os
import random

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Mean Earth radius in kilometres, used by the haversine distance.
EARTH_RADIUS_KM = 6371.0088


def _configure_logging():
    """Configure root logging at INFO level via sentence-transformers' handler.

    Done at run time rather than import time so importing this module has no
    side effects and does not require sentence-transformers to be installed.
    """
    from sentence_transformers import LoggingHandler

    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[LoggingHandler()],
    )


class SignCoordinateDataset(Dataset):
    """In-memory dataset of ``(text, coordinate_tensor)`` pairs.

    Args:
        texts: Iterable of input strings.
        coordinates: Array-like of per-row regression targets, stored as a
            float32 tensor (in this script: normalized ``[lat, lon]`` rows).

    Raises:
        ValueError: If ``texts`` and ``coordinates`` differ in length.
    """

    def __init__(self, texts, coordinates):
        self.texts = list(texts)
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
        if len(self.texts) != len(self.coordinates):
            raise ValueError(
                f"texts ({len(self.texts)}) and coordinates "
                f"({len(self.coordinates)}) must have the same length"
            )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return self.texts[index], self.coordinates[index]


class CoordinateRegressor(nn.Module):
    """MLP head mapping a sentence embedding to two output values.

    Layout: Linear(embedding_dim -> hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim -> hidden_dim // 2) -> ReLU -> Dropout ->
    Linear(hidden_dim // 2 -> 2).
    """

    def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 2),
        )

    def forward(self, embeddings):
        return self.layers(embeddings)


def parse_args():
    """Parse command-line options for a training run."""
    parser = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
    parser.add_argument("--data-file", default="training.csv")
    parser.add_argument("--output-path", default="output")
    parser.add_argument(
        "--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    parser.add_argument("--device", default=None)
    parser.add_argument("--seed", type=int, default=1992)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--head-learning-rate", type=float, default=1e-3)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--test-size", type=float, default=0.2)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)
    return parser.parse_args()


def get_device(requested_device):
    """Return ``requested_device`` if given, else CUDA when available, else CPU."""
    if requested_device:
        return torch.device(requested_device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def set_seed(seed):
    """Seed Python, NumPy and torch RNGs (and all CUDA devices when present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def normalize_coordinates(coordinates):
    """Z-score each column of ``coordinates``.

    Uses NumPy's default population standard deviation (ddof=0). Columns
    with zero spread get a std of 1.0 so the division is safe.

    Returns:
        ``(normalized, mean, std)`` where ``mean``/``std`` are per-column.
    """
    mean = coordinates.mean(axis=0)
    std = coordinates.std(axis=0)
    std[std == 0] = 1.0
    return (coordinates - mean) / std, mean, std


def collate_fn(model, device):
    """Build a DataLoader collate function for ``(text, label)`` batches.

    The returned function tokenizes the texts with ``model.tokenize`` and
    moves tensor features and the stacked labels to ``device``. Note that
    ``main`` uses the DataLoader default ``num_workers=0``, so this runs in
    the main process.
    """

    def collate(batch):
        texts, labels = zip(*batch)
        features = model.tokenize(list(texts))
        features = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        return features, torch.stack(labels).to(device)

    return collate


def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
    """Run one optimization pass and return the sample-weighted mean loss."""
    encoder.train()
    head.train()
    total_loss = 0.0
    for features, labels in dataloader:
        optimizer.zero_grad()
        embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += loss.item() * labels.size(0)
    return total_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std):
    """Evaluate without gradients.

    Returns:
        ``(mean_loss, mean_error_km)``: the sample-weighted loss in
        normalized space, and the mean haversine distance in km between
        de-normalized predictions and targets.
    """
    encoder.eval()
    head.eval()
    total_loss = 0.0
    errors_km = []
    for features, labels in dataloader:
        embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        total_loss += loss.item() * labels.size(0)
        # Undo the z-score so distances are computed in degrees.
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))
    return total_loss / len(dataloader.dataset), float(np.mean(errors_km))


def haversine_km(pred_coords, true_coords):
    """Great-circle distance in km between rows of ``[lat, lon]`` degrees.

    Args:
        pred_coords, true_coords: Arrays of shape ``(n, 2)``.

    Returns:
        float64 array of shape ``(n,)``.
    """
    # float64 avoids float32 rounding error on short distances.
    pred = np.radians(np.asarray(pred_coords, dtype=np.float64))
    true = np.radians(np.asarray(true_coords, dtype=np.float64))
    lat1, lon1 = pred[:, 0], pred[:, 1]
    lat2, lon2 = true[:, 0], true[:, 1]
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Mathematically a is in [0, 1], but rounding can push it slightly
    # outside, which would make sqrt/arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def _write_coordinate_config(output_path, coord_mean, coord_std, args):
    """Write normalization stats and head hyperparameters to coordinate_config.json."""
    metadata = {
        "coord_mean": coord_mean.tolist(),
        "coord_std": coord_std.tolist(),
        "hidden_dim": args.hidden_dim,
        "dropout": args.dropout,
        "model_name": args.model_name,
    }
    config_path = os.path.join(output_path, "coordinate_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


def save_model(output_path, encoder, head, coord_mean, coord_std, args):
    """Save the encoder, head weights and coordinate config into ``output_path``."""
    os.makedirs(output_path, exist_ok=True)
    encoder.save(output_path)
    torch.save(head.state_dict(), os.path.join(output_path, "coordinate_head.pt"))
    _write_coordinate_config(output_path, coord_mean, coord_std, args)


def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
    """Snapshot the untrained encoder/head (``initial_*``) plus the coordinate config."""
    os.makedirs(output_path, exist_ok=True)
    encoder.save(os.path.join(output_path, "initial_encoder"))
    torch.save(
        head.state_dict(),
        os.path.join(output_path, "initial_coordinate_head.pt"),
    )
    _write_coordinate_config(output_path, coord_mean, coord_std, args)


def _split_indices(data, test_size, seed):
    """Return ``(train_indices, val_indices)`` as positional index arrays.

    When an ``intersection`` column exists, ``GroupShuffleSplit`` keeps all
    rows of one intersection on the same side of the split (presumably to
    stop near-duplicate rows leaking into validation); otherwise a plain
    random split is used.
    """
    from sklearn.model_selection import GroupShuffleSplit, train_test_split

    indices = np.arange(len(data))
    if "intersection" in data.columns:
        splitter = GroupShuffleSplit(
            n_splits=1, test_size=test_size, random_state=seed
        )
        train_indices, val_indices = next(
            splitter.split(indices, groups=data["intersection"])
        )
    else:
        train_indices, val_indices = train_test_split(
            indices, test_size=test_size, random_state=seed
        )
    return train_indices, val_indices


def main():
    """Parse CLI args, train encoder + head, and checkpoint the best epoch."""
    args = parse_args()
    _configure_logging()
    # Imported here so the helpers above can be imported without
    # sentence-transformers installed.
    from sentence_transformers import SentenceTransformer

    set_seed(args.seed)
    device = get_device(args.device)
    print(f"Using device: {device}")

    data = pd.read_csv(args.data_file)
    # reset_index keeps row labels aligned with the positional indices below.
    data = data.dropna(subset=["text", "latitude", "longitude"]).reset_index(
        drop=True
    )
    texts = data["text"].astype(str).tolist()
    coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)

    train_indices, val_indices = _split_indices(data, args.test_size, args.seed)

    # Fit normalization statistics on the training split only, so validation
    # targets cannot leak into them; then apply them to every row.
    _, coord_mean, coord_std = normalize_coordinates(coordinates[train_indices])
    normalized_coordinates = (coordinates - coord_mean) / coord_std

    train_dataset = SignCoordinateDataset(
        [texts[i] for i in train_indices], normalized_coordinates[train_indices]
    )
    val_dataset = SignCoordinateDataset(
        [texts[i] for i in val_indices], normalized_coordinates[val_indices]
    )

    encoder = SentenceTransformer(args.model_name, device=str(device))
    embedding_dim = encoder.get_sentence_embedding_dimension()
    head = CoordinateRegressor(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)
    save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn(encoder, device),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn(encoder, device),
    )

    # Separate learning rates: small for the pretrained encoder, larger for
    # the freshly initialized head.
    optimizer = torch.optim.AdamW(
        [
            {"params": encoder.parameters(), "lr": args.learning_rate},
            {"params": head.parameters(), "lr": args.head_learning_rate},
        ],
        weight_decay=args.weight_decay,
    )
    loss_fn = nn.MSELoss()
    best_val_loss = float("inf")

    print(
        f"Training on {len(train_dataset):,} rows; "
        f"validating on {len(val_dataset):,} rows"
    )

    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
        val_loss, val_error_km = evaluate(
            encoder, head, val_loader, loss_fn, coord_mean, coord_std
        )
        print(
            f"epoch={epoch} train_loss={train_loss:.6f} "
            f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
        )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(args.output_path, encoder, head, coord_mean, coord_std, args)
            print(f"Saved best model to {args.output_path}")


if __name__ == "__main__":
    main()