# File listing header carried over with this source (not part of the program):
#   citybert/train.py
#   2026-05-25 14:11:05 -06:00
#
#   274 lines
#   8.8 KiB
#   Python

import argparse
import json
import logging
import os
import random
import numpy as np
import pandas as pd
import torch
from sentence_transformers import LoggingHandler, SentenceTransformer
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Configure root logging once at import time: INFO and above, timestamped to
# the second, emitted through the LoggingHandler imported from
# sentence_transformers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)
class SignCoordinateDataset(Dataset):
    """Pair each sign text with its coordinate target row.

    Args:
        texts: Iterable of texts; materialized into a list so it can be
            indexed repeatedly.
        coordinates: Array-like with one target row per text; copied into a
            float32 tensor.

    Raises:
        ValueError: If the number of texts differs from the number of
            coordinate rows.
    """

    def __init__(self, texts, coordinates):
        self.texts = list(texts)
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
        # Fail at construction time: a mismatch would otherwise surface only
        # as an IndexError mid-epoch, or silently ignore surplus targets.
        if self.coordinates.ndim == 0 or len(self.coordinates) != len(self.texts):
            raise ValueError(
                "texts and coordinates must have the same number of rows; "
                f"got {len(self.texts)} texts and coordinates of shape "
                f"{tuple(self.coordinates.shape)}"
            )

    def __len__(self):
        """Return the number of (text, coordinate) pairs."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return the ``(text, coordinate_tensor)`` pair at ``index``."""
        return self.texts[index], self.coordinates[index]
class CoordinateRegressor(nn.Module):
    """Feed-forward head that maps an embedding to two output values.

    The stack is Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout ->
    Linear, with widths embedding_dim -> hidden_dim -> hidden_dim // 2 -> 2.
    """

    def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        bottleneck_dim = hidden_dim // 2
        # Modules are created in forward order and stored under ``layers`` so
        # parameter initialization order and state-dict keys stay stable.
        stack = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_dim, 2),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, embeddings):
        """Run ``embeddings`` through the stack and return the result."""
        outputs = self.layers(embeddings)
        return outputs
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to the process command line, so
            existing ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with ``data_file``, ``output_path``,
        ``model_name``, ``device``, ``seed``, ``epochs``, ``batch_size``,
        ``learning_rate``, ``head_learning_rate``, ``weight_decay``,
        ``test_size``, ``hidden_dim`` and ``dropout``.
    """
    parser = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
    parser.add_argument(
        "--data-file", default="training.csv", help="CSV file with training rows."
    )
    parser.add_argument(
        "--output-path", default="output", help="Directory for saved artifacts."
    )
    parser.add_argument(
        "--model-name",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Sentence-transformers model to fine-tune.",
    )
    parser.add_argument(
        "--device", default=None, help="Torch device string; auto-detected if unset."
    )
    parser.add_argument("--seed", type=int, default=1992, help="Random seed.")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size.")
    parser.add_argument(
        "--learning-rate", type=float, default=2e-5, help="Encoder learning rate."
    )
    parser.add_argument(
        "--head-learning-rate",
        type=float,
        default=1e-3,
        help="Regression head learning rate.",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.01, help="Optimizer weight decay."
    )
    parser.add_argument(
        "--test-size", type=float, default=0.2, help="Validation split size."
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=256, help="Regression head hidden width."
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Regression head dropout."
    )
    return parser.parse_args(argv)
def get_device(requested_device):
    """Choose the torch device to run on.

    Args:
        requested_device: Device string supplied by the user; any falsy value
            (``None`` or ``""``) triggers auto-detection.

    Returns:
        torch.device: The requested device if given, otherwise the first
        available of CUDA, MPS, then CPU.
    """
    if requested_device:
        return torch.device(requested_device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Look the MPS backend up defensively so PyTorch builds that do not ship
    # ``torch.backends.mps`` fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def normalize_coordinates(coordinates):
    """Standardize each coordinate column to zero mean and unit deviation.

    Columns whose standard deviation is exactly zero are divided by 1.0
    instead, so constant columns come out as zeros rather than NaN.

    Returns:
        Tuple ``(normalized, mean, std)`` where ``mean`` and ``std`` are the
        per-column statistics (taken along axis 0) used for the transform.
    """
    center = coordinates.mean(axis=0)
    scale = coordinates.std(axis=0)
    constant_columns = scale == 0
    scale[constant_columns] = 1.0
    standardized = (coordinates - center) / scale
    return standardized, center, scale
def collate_fn(model, device):
    """Build a collate function that tokenizes texts and moves a batch to ``device``.

    Args:
        model: Object exposing ``tokenize(list_of_texts)`` that returns a
            mapping from feature name to value.
        device: Device that feature tensors and labels are moved to.

    Returns:
        A callable taking a list of ``(text, label_tensor)`` pairs and
        returning ``(features, labels)``, where ``labels`` is the stacked
        label tensor on ``device``.
    """

    def collate(batch):
        texts, labels = zip(*batch)
        features = model.tokenize(list(texts))
        # Only tensors can be moved; pass any other tokenizer output through
        # unchanged instead of failing on a missing ``.to`` method.
        features = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        return features, torch.stack(labels).to(device)

    return collate
def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
    """Run one optimization pass over ``dataloader`` and return the mean loss.

    Args:
        encoder: Callable module returning a mapping with a
            ``"sentence_embedding"`` entry for a batch of features.
        head: Module mapping embeddings to predictions.
        dataloader: Iterable of ``(features, labels)`` batches.
        optimizer: Optimizer updating the encoder and head parameters.
        loss_fn: Loss callable; assumed to return a per-batch mean, which is
            why it is re-weighted by batch size below.

    Returns:
        float: Sample-weighted mean loss over the batches actually processed,
        or NaN when the loader yields no samples.
    """
    encoder.train()
    head.train()
    total_loss = 0.0
    samples_seen = 0
    for features, labels in dataloader:
        optimizer.zero_grad()
        embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size
    # Divide by what was actually consumed rather than the dataset length:
    # the two differ when batches are dropped or sampled, and the dataset
    # length is zero (division error) for an empty dataset.
    if samples_seen == 0:
        return float("nan")
    return total_loss / samples_seen
@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std):
    """Score the model on ``dataloader`` without computing gradients.

    Args:
        encoder: Callable module returning a mapping with a
            ``"sentence_embedding"`` entry for a batch of features.
        head: Module mapping embeddings to predictions.
        dataloader: Iterable of ``(features, labels)`` batches whose labels
            are in normalized coordinate space.
        loss_fn: Loss callable; assumed to return a per-batch mean, which is
            why it is re-weighted by batch size below.
        coord_mean: Per-column offset used to undo label normalization.
        coord_std: Per-column scale used to undo label normalization.

    Returns:
        Tuple ``(mean_loss, mean_error_km)``: the sample-weighted mean loss
        and the mean haversine distance between de-normalized predictions
        and labels. Both are NaN when the loader yields no samples.
    """
    encoder.eval()
    head.eval()
    total_loss = 0.0
    samples_seen = 0
    errors_km = []
    for features, labels in dataloader:
        embeddings = encoder(features)["sentence_embedding"]
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size
        # Map both sides back to the original coordinate scale before
        # measuring distance.
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))
    # Average over what was actually consumed; this also avoids a division
    # error (and NumPy's empty-mean warning) when there is nothing to score.
    if samples_seen == 0:
        return float("nan"), float("nan")
    return total_loss / samples_seen, float(np.mean(errors_km))
def haversine_km(pred_coords, true_coords):
    """Return great-circle distances in kilometres between paired points.

    Args:
        pred_coords: Array-like of rows whose first two columns are
            ``(latitude, longitude)`` in degrees.
        true_coords: Array-like with the same layout, paired row by row with
            ``pred_coords``.

    Returns:
        numpy.ndarray: One float64 distance per row.
    """
    # Compute in float64: float32 inputs lose precision on nearby points.
    pred = np.asarray(pred_coords, dtype=np.float64)
    true = np.asarray(true_coords, dtype=np.float64)
    lat1, lon1 = np.radians(pred[:, 0]), np.radians(pred[:, 1])
    lat2, lon2 = np.radians(true[:, 0]), np.radians(true[:, 1])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Rounding can push ``a`` marginally outside [0, 1] (e.g. near-antipodal
    # points), which would make sqrt/arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    earth_radius_km = 6371.0088
    return 2 * earth_radius_km * np.arcsin(np.sqrt(a))
def save_model(output_path, encoder, head, coord_mean, coord_std, args):
    """Write the encoder, head weights and coordinate config to ``output_path``.

    Creates the directory if needed, then stores:
      * the encoder, via ``encoder.save(output_path)``;
      * the head's state dict as ``coordinate_head.pt``;
      * ``coordinate_config.json`` holding the normalization statistics and
        the ``hidden_dim``, ``dropout`` and ``model_name`` values from ``args``.
    """
    os.makedirs(output_path, exist_ok=True)
    encoder.save(output_path)

    head_file = os.path.join(output_path, "coordinate_head.pt")
    torch.save(head.state_dict(), head_file)

    config = dict(
        coord_mean=coord_mean.tolist(),
        coord_std=coord_std.tolist(),
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        model_name=args.model_name,
    )
    config_file = os.path.join(output_path, "coordinate_config.json")
    with open(config_file, "w") as handle:
        json.dump(config, handle, indent=2)
def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
    """Snapshot the encoder and head under ``output_path`` along with the config.

    Creates the directory if needed, then stores:
      * the encoder, via ``encoder.save`` into the ``initial_encoder``
        subdirectory;
      * the head's state dict as ``initial_coordinate_head.pt``;
      * ``coordinate_config.json`` holding the normalization statistics and
        the ``hidden_dim``, ``dropout`` and ``model_name`` values from ``args``.
    """
    os.makedirs(output_path, exist_ok=True)

    encoder_dir = os.path.join(output_path, "initial_encoder")
    encoder.save(encoder_dir)

    head_file = os.path.join(output_path, "initial_coordinate_head.pt")
    torch.save(head.state_dict(), head_file)

    config = dict(
        coord_mean=coord_mean.tolist(),
        coord_std=coord_std.tolist(),
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        model_name=args.model_name,
    )
    config_file = os.path.join(output_path, "coordinate_config.json")
    with open(config_file, "w") as handle:
        json.dump(config, handle, indent=2)
def _load_examples(data_file):
    """Read the CSV and return ``(frame, texts, coordinates)`` for usable rows."""
    data = pd.read_csv(data_file)
    data = data.dropna(subset=["text", "latitude", "longitude"])
    if data.empty:
        # Fail here with a clear message instead of letting NaN statistics
        # and an opaque splitter error surface further down.
        raise ValueError(
            f"No rows with text, latitude and longitude found in {data_file}"
        )
    texts = data["text"].astype(str).tolist()
    coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    return data, texts, coordinates


def _split_indices(data, test_size, seed):
    """Return positional ``(train_indices, val_indices)`` for ``data``.

    When an ``intersection`` column is present it is used as the grouping key
    for the split; otherwise rows are split individually.
    """
    indices = np.arange(len(data))
    if "intersection" in data.columns:
        splitter = GroupShuffleSplit(
            n_splits=1, test_size=test_size, random_state=seed
        )
        train_indices, val_indices = next(
            splitter.split(indices, groups=data["intersection"])
        )
    else:
        train_indices, val_indices = train_test_split(
            indices, test_size=test_size, random_state=seed
        )
    return train_indices, val_indices


def _make_loader(dataset, batch_size, shuffle, collate):
    """Wrap ``dataset`` in a DataLoader using the shared collate function."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate,
    )


def main():
    """Train the text-to-coordinate model end to end.

    Parses the command line, loads and splits the data, builds the encoder
    and regression head, saves their initial state, then trains for the
    requested number of epochs, saving the model whenever validation loss
    improves.

    Raises:
        ValueError: If the data file has no rows with text, latitude and
            longitude.
    """
    args = parse_args()
    set_seed(args.seed)
    device = get_device(args.device)
    print(f"Using device: {device}")

    data, texts, coordinates = _load_examples(args.data_file)
    # NOTE(review): normalization statistics are computed over all rows,
    # before the train/validation split, so validation targets contribute to
    # the saved mean/std. Confirm this is intended.
    normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
    train_indices, val_indices = _split_indices(data, args.test_size, args.seed)
    train_dataset = SignCoordinateDataset(
        [texts[i] for i in train_indices], normalized_coordinates[train_indices]
    )
    val_dataset = SignCoordinateDataset(
        [texts[i] for i in val_indices], normalized_coordinates[val_indices]
    )

    encoder = SentenceTransformer(args.model_name, device=str(device))
    embedding_dim = encoder.get_sentence_embedding_dimension()
    head = CoordinateRegressor(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)
    save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)

    collate = collate_fn(encoder, device)
    train_loader = _make_loader(train_dataset, args.batch_size, True, collate)
    val_loader = _make_loader(val_dataset, args.batch_size, False, collate)

    # Separate parameter groups: the encoder and the head each get their own
    # learning rate from the command line.
    optimizer = torch.optim.AdamW(
        [
            {"params": encoder.parameters(), "lr": args.learning_rate},
            {"params": head.parameters(), "lr": args.head_learning_rate},
        ],
        weight_decay=args.weight_decay,
    )
    loss_fn = nn.MSELoss()
    best_val_loss = float("inf")
    print(
        f"Training on {len(train_dataset):,} rows; "
        f"validating on {len(val_dataset):,} rows"
    )
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
        val_loss, val_error_km = evaluate(
            encoder, head, val_loader, loss_fn, coord_mean, coord_std
        )
        print(
            f"epoch={epoch} train_loss={train_loss:.6f} "
            f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
        )
        # Keep only the best checkpoint: each improvement overwrites the
        # files previously written to the output path.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(args.output_path, encoder, head, coord_mean, coord_std, args)
            print(f"Saved best model to {args.output_path}")
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()