- prepare_training_data: bag_size raised from 5 to 8; street signs fill slots first, so every sample contains the most geographically discriminative texts.
- train: HuberLoss replaces MSE (robust to outlier intersections); ReduceLROnPlateau scheduler added; split logic extracted to data_utils.
- eval: reproduces the train/val split to report honest per-bag and per-intersection-aggregated metrics separately for the train and val sets.
- data_utils: shared split_indices() so that train and eval use identical splits.

Co-authored-by: Cursor <cursoragent@cursor.com>
586 lines
19 KiB
Python
586 lines
19 KiB
Python
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
from sentence_transformers import LoggingHandler, SentenceTransformer
|
|
from data_utils import split_indices
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
# Configure root logging once at import time: INFO level, timestamped
# messages, emitted through the LoggingHandler imported from
# sentence_transformers.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(message)s",
)
|
|
|
|
|
|
class SignCoordinateDataset(Dataset):
    """Dataset of ``(text, coordinate)`` pairs.

    ``texts`` is materialised into a list and ``coordinates`` is converted to
    a ``float32`` tensor; item ``i`` is the tuple
    ``(texts[i], coordinates[i])``.
    """

    def __init__(self, texts, coordinates):
        self.texts = [*texts]
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)

    def __len__(self):
        # The text list defines the number of samples.
        return len(self.texts)

    def __getitem__(self, index):
        text = self.texts[index]
        target = self.coordinates[index]
        return text, target
|
|
|
|
|
|
class EmbeddingCoordinateDataset(Dataset):
    """Dataset of ``(embedding, coordinate)`` pairs.

    Both inputs are converted to ``float32`` tensors; item ``i`` is the tuple
    ``(embeddings[i], coordinates[i])``.
    """

    def __init__(self, embeddings, coordinates):
        as_float = torch.float32
        self.embeddings = torch.tensor(embeddings, dtype=as_float)
        self.coordinates = torch.tensor(coordinates, dtype=as_float)

    def __len__(self):
        # The embedding tensor's first dimension defines the number of samples.
        return len(self.embeddings)

    def __getitem__(self, index):
        embedding = self.embeddings[index]
        target = self.coordinates[index]
        return embedding, target
|
|
|
|
|
|
class CoordinateRegressor(nn.Module):
    """Feed-forward head mapping an embedding to two output values.

    Layout: ``embedding_dim -> hidden_dim -> hidden_dim // 2 -> 2``, with a
    ReLU and a dropout layer after each of the first two linear layers.
    """

    def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        narrow_dim = hidden_dim // 2
        # Module order is kept fixed so the ``layers.N`` state-dict keys and
        # the parameter initialisation order stay the same.
        stack = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, narrow_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(narrow_dim, 2),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, embeddings):
        outputs = self.layers(embeddings)
        return outputs
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed training options.
    """
    parser = argparse.ArgumentParser(description="Train sign text to lat/lon model.")
    add = parser.add_argument

    # Data, output location, base model and device.
    add("--data-file", default="training.csv")
    add("--output-path", default="output")
    add("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
    add(
        "--device",
        default="cuda",
        help="Device to use for training. Defaults to `cuda`.",
    )

    # Reproducibility and loop sizing.
    add("--seed", type=int, default=1992)
    add("--epochs", type=int, default=10)
    add("--batch-size", type=int, default=64)
    add(
        "--num-workers",
        type=int,
        default=2,
        help="DataLoader workers to prefetch batches while the GPU trains.",
    )
    add(
        "--save-every-epochs",
        type=int,
        default=5,
        help=(
            "Write the best checkpoint to disk every N epochs (and at the final "
            "epoch). Validation still runs every epoch."
        ),
    )

    # Optimisation and model hyper-parameters.
    add("--learning-rate", type=float, default=1e-4)
    add("--head-learning-rate", type=float, default=5e-2)
    add("--weight-decay", type=float, default=0.001)
    add("--test-size", type=float, default=0.2)
    add("--hidden-dim", type=int, default=256)
    add("--dropout", type=float, default=0.1)

    # Encoder freezing switches.
    add(
        "--freeze-encoder",
        action="store_true",
        help="Train only the coordinate head; keep the sentence encoder fixed.",
    )
    add(
        "--freeze-transformer-layers",
        type=int,
        default=0,
        help="Freeze the first N transformer layers in the sentence encoder.",
    )
    add(
        "--freeze-attention",
        action="store_true",
        help=(
            "Freeze self-attention parameters while leaving other encoder "
            "params trainable."
        ),
    )

    return parser.parse_args()
|
|
|
|
|
|
def get_device(requested_device):
    """Resolve the torch device to train on.

    Args:
        requested_device: Device string such as ``"cuda"``, ``"cuda:0"`` or
            ``"cpu"``. A falsy value selects CUDA when available, else CPU.

    Returns:
        torch.device: The requested device. A CUDA request on a machine where
        ``torch.cuda.is_available()`` is false falls back to CPU and logs a
        warning.
    """
    cuda_available = torch.cuda.is_available()
    if not requested_device:
        return torch.device("cuda" if cuda_available else "cpu")

    device = torch.device(requested_device)
    # The CLI default for --device is "cuda", so without this check the
    # auto-detect branch above is never reached and a CUDA-less machine would
    # be handed a device it cannot use.
    if device.type == "cuda" and not cuda_available:
        logging.warning(
            "CUDA device %r requested but CUDA is not available; "
            "falling back to CPU.",
            requested_device,
        )
        return torch.device("cpu")
    return device
|
|
|
|
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same ``seed``.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def normalize_coordinates(coordinates):
    """Standardise ``coordinates`` column-wise.

    The mean and standard deviation are taken along axis 0. Columns whose
    standard deviation is exactly zero are divided by 1.0 instead, so they
    normalise to zeros rather than NaN.

    Returns:
        tuple: ``(normalized, mean, std)`` where ``std`` already has its zero
        entries replaced by 1.0.
    """
    mean = coordinates.mean(axis=0)
    std = coordinates.std(axis=0)
    constant_columns = std == 0
    std[constant_columns] = 1.0
    normalized = (coordinates - mean) / std
    return normalized, mean, std
|
|
|
|
|
|
def move_features_to_device(features, device, non_blocking=False):
    """Return a new dict with every value of ``features`` moved to ``device``.

    Keys and their order are preserved; ``non_blocking`` is forwarded to each
    value's ``.to()`` call.
    """
    moved = {}
    for name, tensor in features.items():
        moved[name] = tensor.to(device, non_blocking=non_blocking)
    return moved
|
|
|
|
|
|
def move_batch_to_device(features, labels, device, pin_memory=False):
    """Move a ``(features dict, labels tensor)`` batch to ``device``.

    When ``pin_memory`` is set and the target is a CUDA device, every tensor
    is pinned first and the transfer is issued with ``non_blocking=True``.

    Returns:
        tuple: ``(features_on_device, labels_on_device)``.
    """
    use_pinned_transfer = pin_memory and device.type == "cuda"
    if use_pinned_transfer:
        features = {name: tensor.pin_memory() for name, tensor in features.items()}
        labels = labels.pin_memory()

    features_on_device = move_features_to_device(
        features, device, non_blocking=use_pinned_transfer
    )
    labels_on_device = labels.to(device, non_blocking=use_pinned_transfer)
    return features_on_device, labels_on_device
|
|
|
|
|
|
def move_tensors_to_device(tensors, device, pin_memory=False):
    """Move each tensor in ``tensors`` to ``device`` and return them as a list.

    When ``pin_memory`` is set and the target is a CUDA device, each tensor is
    pinned first and transferred with ``non_blocking=True``.
    """
    non_blocking = pin_memory and device.type == "cuda"
    moved = []
    for tensor in tensors:
        if non_blocking:
            tensor = tensor.pin_memory()
        moved.append(tensor.to(device, non_blocking=non_blocking))
    return moved
|
|
|
|
|
|
def make_text_collate(tokenize):
    """Create a collate function for batches of ``(text, label)`` samples.

    The returned callable passes the batch's texts (as a list) to ``tokenize``
    and stacks the labels into one tensor, returning
    ``(tokenize_output, stacked_labels)``.
    """

    def collate(batch):
        texts, labels = map(list, zip(*batch))
        return tokenize(texts), torch.stack(labels)

    return collate
|
|
|
|
|
|
def embedding_collate(batch):
    """Collate ``(embedding, label)`` samples into two stacked tensors.

    Returns:
        tuple: ``(stacked_embeddings, stacked_labels)``.
    """
    embedding_batch, label_batch = map(torch.stack, zip(*batch))
    return embedding_batch, label_batch
|
|
|
|
|
|
def make_dataloader(dataset, batch_size, shuffle, collate_fn, num_workers, pin_memory):
    """Construct a ``DataLoader`` from the given settings.

    ``persistent_workers=True`` is passed only when ``num_workers`` is
    positive; otherwise the argument is left at the DataLoader default.
    """
    worker_options = {"persistent_workers": True} if num_workers > 0 else {}
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **worker_options,
    )
|
|
|
|
|
|
def copy_module_state(module):
    """Return an independent CPU snapshot of ``module``'s state dict.

    ``state_dict()`` hands back tensors that share storage with the live
    parameters, and ``Tensor.cpu()`` returns the very same tensor when it is
    already on the CPU. A plain ``.detach().cpu()`` therefore aliases the
    parameters for CPU-resident modules, and the "snapshot" would silently
    follow every later optimizer step. ``copy=True`` forces a real copy on
    every device.

    Returns:
        dict: state-dict key -> detached CPU tensor that does not share memory
        with ``module``.
    """
    return {
        key: value.detach().to(device="cpu", copy=True)
        for key, value in module.state_dict().items()
    }
|
|
|
|
|
|
def save_best_checkpoint(
    output_path, encoder, head, best_states, coord_mean, coord_std, args
):
    """Write the best-so-far weights to disk without disturbing training.

    The best weights from ``best_states`` are loaded into ``encoder`` and
    ``head`` for the duration of the ``save_model`` call, then the weights the
    modules held on entry are put back.

    Args:
        output_path: Directory handed to ``save_model``.
        encoder: Encoder module whose weights are temporarily swapped.
        head: Head module whose weights are temporarily swapped.
        best_states: Dict with ``"encoder"`` and ``"head"`` state dicts.
        coord_mean: Forwarded to ``save_model``.
        coord_std: Forwarded to ``save_model``.
        args: Forwarded to ``save_model``.
    """
    # state_dict() returns references to the live parameters and
    # load_state_dict() copies into them in place, so the current weights must
    # be cloned up front. Without the clone, the "restore" below would just
    # reload the best weights and training would continue from them.
    current_encoder_state = {
        key: value.detach().clone() for key, value in encoder.state_dict().items()
    }
    current_head_state = {
        key: value.detach().clone() for key, value in head.state_dict().items()
    }
    try:
        encoder.load_state_dict(best_states["encoder"])
        head.load_state_dict(best_states["head"])
        save_model(output_path, encoder, head, coord_mean, coord_std, args)
    finally:
        # Restore the in-progress weights even if saving failed.
        encoder.load_state_dict(current_encoder_state)
        head.load_state_dict(current_head_state)
|
|
|
|
|
|
def should_save_checkpoint(epoch, total_epochs, save_every_epochs, pending_save):
    """Decide whether a pending best checkpoint should be written this epoch.

    Returns ``False`` when nothing is pending. Otherwise returns ``True`` on
    the final epoch or on any epoch that is a multiple of
    ``save_every_epochs``.
    """
    is_final_epoch = epoch == total_epochs
    # The modulo is evaluated only when a save is pending and this is not the
    # final epoch, matching the order of the checks described above.
    return bool(pending_save) and (
        is_final_epoch or epoch % save_every_epochs == 0
    )
|
|
|
|
|
|
@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
    """Embed ``texts`` in batches and return one stacked NumPy array.

    The encoder is switched to eval mode. Each batch is tokenised with
    ``encoder.tokenize``, moved to ``device`` and run through the encoder,
    and its ``"sentence_embedding"`` output is collected on the CPU.
    """
    encoder.eval()
    chunks = []
    for offset in range(0, len(texts), batch_size):
        tokens = encoder.tokenize(texts[offset : offset + batch_size])
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}
        sentence_vectors = encoder(tokens)["sentence_embedding"]
        chunks.append(sentence_vectors.cpu().numpy())
    return np.vstack(chunks)
|
|
|
|
|
|
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
    """Train ``head`` for one epoch on batches of ``(embeddings, labels)``.

    Returns:
        float: the batch losses weighted by batch size, summed and divided by
        ``len(dataloader.dataset)``.
    """
    head.train()
    pin_memory = dataloader.pin_memory
    weighted_loss_sum = 0.0

    for batch_embeddings, batch_labels in dataloader:
        batch_embeddings, batch_labels = move_tensors_to_device(
            [batch_embeddings, batch_labels], device, pin_memory=pin_memory
        )
        optimizer.zero_grad(set_to_none=True)
        batch_loss = loss_fn(head(batch_embeddings), batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller last batch is averaged correctly.
        weighted_loss_sum += batch_loss.item() * batch_labels.size(0)

    return weighted_loss_sum / len(dataloader.dataset)
|
|
|
|
|
|
@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate ``head`` on batches of ``(embeddings, labels)``.

    Predictions and labels are mapped back to the original coordinate scale
    with ``coord_std`` and ``coord_mean`` before distances are computed with
    ``haversine_km``.

    Returns:
        tuple: ``(mean loss over the dataset, mean haversine error in km)``.
    """
    head.eval()
    pin_memory = dataloader.pin_memory
    weighted_loss_sum = 0.0
    prediction_batches = []
    label_batches = []

    for embeddings, labels in dataloader:
        embeddings, labels = move_tensors_to_device(
            [embeddings, labels], device, pin_memory=pin_memory
        )
        predictions = head(embeddings)
        weighted_loss_sum += loss_fn(predictions, labels).item() * labels.size(0)
        prediction_batches.append(predictions)
        label_batches.append(labels)

    def denormalize(batches):
        # Undo the standardisation applied to the training targets.
        return torch.cat(batches).float().cpu().numpy() * coord_std + coord_mean

    pred_coords = denormalize(prediction_batches)
    true_coords = denormalize(label_batches)
    errors_km = haversine_km(pred_coords, true_coords)

    mean_loss = weighted_loss_sum / len(dataloader.dataset)
    return mean_loss, float(np.mean(errors_km))
|
|
|
|
|
|
def train_epoch(
    encoder,
    head,
    dataloader,
    optimizer,
    loss_fn,
    device,
    encoder_trainable,
):
    """Train encoder and head for one epoch on tokenised text batches.

    When ``encoder_trainable`` is falsy the encoder stays in eval mode and
    its forward pass runs under ``torch.no_grad()``, so only the head
    receives gradients.

    Returns:
        float: the batch losses weighted by batch size, summed and divided by
        ``len(dataloader.dataset)``.
    """
    if encoder_trainable:
        encoder.train()
    else:
        encoder.eval()
    head.train()

    pin_memory = dataloader.pin_memory
    weighted_loss_sum = 0.0

    for features, labels in dataloader:
        features, labels = move_batch_to_device(
            features, labels, device, pin_memory=pin_memory
        )
        optimizer.zero_grad(set_to_none=True)

        if not encoder_trainable:
            # Frozen encoder: skip building its autograd graph.
            with torch.no_grad():
                embeddings = encoder(features)["sentence_embedding"]
        else:
            embeddings = encoder(features)["sentence_embedding"]

        batch_loss = loss_fn(head(embeddings), labels)
        batch_loss.backward()
        optimizer.step()
        weighted_loss_sum += batch_loss.item() * labels.size(0)

    return weighted_loss_sum / len(dataloader.dataset)
|
|
|
|
|
|
@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate encoder plus head on tokenised text batches.

    Predictions and labels are mapped back to the original coordinate scale
    with ``coord_std`` and ``coord_mean`` before distances are computed with
    ``haversine_km``.

    Returns:
        tuple: ``(mean loss over the dataset, mean haversine error in km)``.
    """
    encoder.eval()
    head.eval()
    pin_memory = dataloader.pin_memory
    weighted_loss_sum = 0.0
    prediction_batches = []
    label_batches = []

    for features, labels in dataloader:
        features, labels = move_batch_to_device(
            features, labels, device, pin_memory=pin_memory
        )
        predictions = head(encoder(features)["sentence_embedding"])
        weighted_loss_sum += loss_fn(predictions, labels).item() * labels.size(0)
        prediction_batches.append(predictions)
        label_batches.append(labels)

    def denormalize(batches):
        # Undo the standardisation applied to the training targets.
        return torch.cat(batches).float().cpu().numpy() * coord_std + coord_mean

    pred_coords = denormalize(prediction_batches)
    true_coords = denormalize(label_batches)
    errors_km = haversine_km(pred_coords, true_coords)

    mean_loss = weighted_loss_sum / len(dataloader.dataset)
    return mean_loss, float(np.mean(errors_km))
|
|
|
|
|
|
def haversine_km(pred_coords, true_coords):
    """Great-circle distance in kilometres between coordinate pairs.

    Args:
        pred_coords: Array-like whose last axis holds ``(latitude, longitude)``
            in degrees, e.g. shape ``(n, 2)`` or a single pair of shape ``(2,)``.
        true_coords: Array-like with the same layout as ``pred_coords``.

    Returns:
        numpy.ndarray: float64 distances with the last axis removed (a NumPy
        scalar for a single pair).
    """
    earth_radius_km = 6371.0088  # mean Earth radius

    # Work in float64: the callers in this file pass float32 arrays, and the
    # trigonometric terms lose precision at that width.
    pred = np.radians(np.asarray(pred_coords, dtype=np.float64))
    true = np.radians(np.asarray(true_coords, dtype=np.float64))
    lat1, lon1 = pred[..., 0], pred[..., 1]
    lat2, lon2 = true[..., 0], true[..., 1]

    half_dlat = (lat2 - lat1) / 2
    half_dlon = (lon2 - lon1) / 2
    a = np.sin(half_dlat) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(half_dlon) ** 2
    # Rounding can push ``a`` marginally outside [0, 1] (for example with
    # near-antipodal points), which would make arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    return 2 * earth_radius_km * np.arcsin(np.sqrt(a))
|
|
|
|
|
|
def save_model(output_path, encoder, head, coord_mean, coord_std, args):
    """Persist the encoder, the head weights and the coordinate metadata.

    Written into ``output_path`` (created if missing):
      * the encoder, via ``encoder.save(output_path)``;
      * ``coordinate_head.pt`` with the head's state dict;
      * ``coordinate_config.json`` with the normalisation statistics and the
        head and model settings taken from ``args``.
    """
    os.makedirs(output_path, exist_ok=True)
    encoder.save(output_path)

    head_file = os.path.join(output_path, "coordinate_head.pt")
    torch.save(head.state_dict(), head_file)

    config = dict(
        coord_mean=coord_mean.tolist(),
        coord_std=coord_std.tolist(),
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        model_name=args.model_name,
    )
    config_file = os.path.join(output_path, "coordinate_config.json")
    with open(config_file, "w") as handle:
        json.dump(config, handle, indent=2)
|
|
|
|
|
|
def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
    """Persist the untrained encoder and head plus the coordinate metadata.

    Written into ``output_path`` (created if missing):
      * the encoder, under the ``initial_encoder`` sub-directory;
      * ``initial_coordinate_head.pt`` with the head's state dict;
      * ``coordinate_config.json`` with the normalisation statistics and the
        head and model settings taken from ``args``.
    """
    os.makedirs(output_path, exist_ok=True)

    encoder_dir = os.path.join(output_path, "initial_encoder")
    encoder.save(encoder_dir)

    head_file = os.path.join(output_path, "initial_coordinate_head.pt")
    torch.save(head.state_dict(), head_file)

    config = dict(
        coord_mean=coord_mean.tolist(),
        coord_std=coord_std.tolist(),
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        model_name=args.model_name,
    )
    config_file = os.path.join(output_path, "coordinate_config.json")
    with open(config_file, "w") as handle:
        json.dump(config, handle, indent=2)
|
|
|
|
|
|
def freeze_encoder_parts(encoder, args):
    """Disable gradients on the encoder parameters selected by ``args``.

    * ``args.freeze_encoder``: freeze every encoder parameter and stop.
    * ``args.freeze_transformer_layers > 0``: freeze the first N entries of
      ``encoder[0].auto_model.encoder.layer``.
    * ``args.freeze_attention``: freeze every transformer parameter whose name
      contains ``".attention."`` or starts with ``"attention."``.
    """

    def freeze(parameters):
        for parameter in parameters:
            parameter.requires_grad = False

    if args.freeze_encoder:
        freeze(encoder.parameters())
        return

    transformer = encoder[0].auto_model

    frozen_layer_count = args.freeze_transformer_layers
    if frozen_layer_count > 0:
        for layer in transformer.encoder.layer[:frozen_layer_count]:
            freeze(layer.parameters())

    if args.freeze_attention:
        freeze(
            parameter
            for name, parameter in transformer.named_parameters()
            if ".attention." in name or name.startswith("attention.")
        )
|
|
|
|
|
|
def count_trainable_parameters(module):
    """Count the scalar parameters of ``module``.

    Returns:
        tuple: ``(trainable, total)`` where ``trainable`` counts only the
        parameters with ``requires_grad`` set.
    """
    trainable = 0
    total = 0
    for parameter in module.parameters():
        size = parameter.numel()
        total += size
        if parameter.requires_grad:
            trainable += size
    return trainable, total
|
|
|
|
|
|
def make_optimizer(encoder, head, args):
    """Create an AdamW optimizer with separate encoder and head learning rates.

    The encoder group, using ``args.learning_rate``, is included only when the
    encoder still has trainable parameters, and it comes first. The head
    group, using ``args.head_learning_rate``, is always last.
    ``args.weight_decay`` applies to every group.
    """
    head_group = {"params": head.parameters(), "lr": args.head_learning_rate}
    trainable_encoder_params = [
        parameter for parameter in encoder.parameters() if parameter.requires_grad
    ]

    if trainable_encoder_params:
        encoder_group = {"params": trainable_encoder_params, "lr": args.learning_rate}
        parameter_groups = [encoder_group, head_group]
    else:
        parameter_groups = [head_group]

    return torch.optim.AdamW(parameter_groups, weight_decay=args.weight_decay)
|
|
|
|
|
|
def main():
    """Train the coordinate model end to end from the command line.

    Steps: parse the options and seed the RNGs, load and normalise the CSV
    data, split it with ``split_indices``, build the encoder and regression
    head, snapshot the initial state, apply the requested freezing, then
    train for ``args.epochs`` epochs. The lowest validation loss is tracked
    and the best checkpoint is written to ``args.output_path`` whenever
    ``should_save_checkpoint`` allows it.
    """
    args = parse_args()
    set_seed(args.seed)
    device = get_device(args.device)
    pin_memory = device.type == "cuda"
    print(f"Using device: {device}")

    # Load the rows that have a text and both coordinates.
    data = pd.read_csv(args.data_file)
    data = data.dropna(subset=["text", "latitude", "longitude"])
    texts = data["text"].astype(str).tolist()
    coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)

    train_indices, val_indices = split_indices(
        data, test_size=args.test_size, seed=args.seed
    )

    train_dataset = SignCoordinateDataset(
        [texts[i] for i in train_indices], normalized_coordinates[train_indices]
    )
    val_dataset = SignCoordinateDataset(
        [texts[i] for i in val_indices], normalized_coordinates[val_indices]
    )

    # Build the model and record its untouched starting point before freezing.
    encoder = SentenceTransformer(args.model_name, device=str(device))
    encoder.to(device)
    embedding_dim = encoder.get_sentence_embedding_dimension()
    head = CoordinateRegressor(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)
    save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
    freeze_encoder_parts(encoder, args)

    encoder_trainable, encoder_total = count_trainable_parameters(encoder)
    head_trainable, head_total = count_trainable_parameters(head)
    print(
        f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
        f"head params: {head_trainable:,}/{head_total:,}"
    )
    encoder_is_frozen = encoder_trainable == 0

    if encoder_is_frozen:
        # With a fully frozen encoder the embeddings never change, so they are
        # computed once and the head trains on them directly.
        print("Caching frozen encoder embeddings...")
        all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
        train_dataset = EmbeddingCoordinateDataset(
            all_embeddings[train_indices], normalized_coordinates[train_indices]
        )
        val_dataset = EmbeddingCoordinateDataset(
            all_embeddings[val_indices], normalized_coordinates[val_indices]
        )
        collate_fn = embedding_collate
    else:
        collate_fn = make_text_collate(encoder.tokenize)

    loader_options = {
        "collate_fn": collate_fn,
        "num_workers": args.num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = make_dataloader(
        train_dataset, args.batch_size, shuffle=True, **loader_options
    )
    val_loader = make_dataloader(
        val_dataset, args.batch_size, shuffle=False, **loader_options
    )

    optimizer = make_optimizer(encoder, head, args)
    loss_fn = nn.HuberLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7
    )

    best_val_loss = float("inf")
    best_states = None
    pending_save = False

    print(
        f"Training on {len(train_dataset):,} rows; "
        f"validating on {len(val_dataset):,} rows; "
        f"batch_size={args.batch_size}; num_workers={args.num_workers}"
    )

    for epoch in range(1, args.epochs + 1):
        if encoder_is_frozen:
            train_loss = train_head_epoch(
                head, train_loader, optimizer, loss_fn, device
            )
            val_loss, val_error_km = evaluate_head(
                head, val_loader, loss_fn, coord_mean, coord_std, device
            )
        else:
            train_loss = train_epoch(
                encoder,
                head,
                train_loader,
                optimizer,
                loss_fn,
                device,
                encoder_trainable > 0,
            )
            val_loss, val_error_km = evaluate(
                encoder,
                head,
                val_loader,
                loss_fn,
                coord_mean,
                coord_std,
                device,
            )

        scheduler.step(val_loss)
        # The head's parameter group is always the last one in the optimizer.
        current_lr = optimizer.param_groups[-1]["lr"]
        print(
            f"epoch={epoch} train_loss={train_loss:.6f} "
            f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} "
            f"lr={current_lr:.2e}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_states = {
                "encoder": copy_module_state(encoder),
                "head": copy_module_state(head),
            }
            pending_save = True

        if not should_save_checkpoint(
            epoch, args.epochs, args.save_every_epochs, pending_save
        ):
            continue

        save_best_checkpoint(
            args.output_path,
            encoder,
            head,
            best_states,
            coord_mean,
            coord_std,
            args,
        )
        pending_save = False
        print(
            f"Saved best model to {args.output_path} "
            f"(val_loss={best_val_loss:.6f})"
        )
|
|
|
|
|
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|