From b5810dd2824d90b84ef43bdd89bfdbb7803bd729 Mon Sep 17 00:00:00 2001
From: Michael Pilosov
Date: Mon, 25 May 2026 15:16:19 -0600
Subject: [PATCH] more training data + frozen layers options

---
Notes (not part of the commit message; ignored when the patch is applied):

  Summary of the change, per file:
    * .gitignore: also ignore `.DS_Store` and any path matching `output*`.
    * Makefile: new `train_frozen_encoder` / `eval_frozen_encoder` targets
      (50 epochs, model in `output_frozen_encoder`) and `train_frozen_layers`
      / `eval_frozen_layers` targets (`--freeze-transformer-layers 4`, model
      in `output_frozen_layers`); all four are added to `.PHONY`.
    * prepare_training_data.py: default `--samples-per-intersection` goes
      from 50 to 100.
    * train.py: new `--freeze-encoder`, `--freeze-transformer-layers` and
      `--freeze-attention` flags. When no encoder parameter is trainable,
      sentence embeddings are computed once by `encode_texts` and only the
      coordinate head is trained on the cached embeddings. The optimizer
      only receives encoder parameters that still require gradients.

  Review observations:
    * `main` always calls `train_epoch(..., True)`, so the
      `encoder_trainable=False` branch of `train_epoch` is not exercised by
      this patch.
    * `freeze_encoder_parts` reads `encoder[0].auto_model.encoder.layer`;
      this presumably matches the encoder in use -- confirm before switching
      to a different sentence-encoder architecture.
    * The layer slice `[: N]` does not fail when N exceeds the number of
      layers, and `N <= 0` freezes nothing; neither case is reported.

 .gitignore               |   2 +
 Makefile                 |  18 +++-
 prepare_training_data.py |   2 +-
 train.py                 | 213 ++++++++++++++++++++++++++++++++++-----
 4 files changed, 207 insertions(+), 28 deletions(-)

diff --git a/.gitignore b/.gitignore
index 033ef8b..03b6b1b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,6 @@ plots*
 *.csv
 output/
 .requirements_installed
+.DS_Store
 __pycache__/
+output*
diff --git a/Makefile b/Makefile
index fe64108..c0f51c7 100644
--- a/Makefile
+++ b/Makefile
@@ -14,6 +14,22 @@ eval: eval.py training.csv
 	@echo "Evaluating coordinate regressor..."
 	@bash -c 'source .venv/bin/activate && python eval.py'
 
+train_frozen_encoder: train.py training.csv
+	@echo "Training coordinate head with frozen encoder..."
+	@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_encoder --freeze-encoder --epochs 50'
+
+eval_frozen_encoder: eval.py training.csv
+	@echo "Evaluating frozen-encoder coordinate regressor..."
+	@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_encoder --output-file predictions_frozen_encoder.csv --plot-file plots/prediction_map_frozen_encoder.png --scatter-plot-file plots/predicted_vs_actual_frozen_encoder.png'
+
+train_frozen_layers: train.py training.csv
+	@echo "Training coordinate regressor with first transformer layers frozen..."
+	@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_layers --freeze-transformer-layers 4'
+
+eval_frozen_layers: eval.py training.csv
+	@echo "Evaluating frozen-layer coordinate regressor..."
+	@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_layers --output-file predictions_frozen_layers.csv --plot-file plots/prediction_map_frozen_layers.png --scatter-plot-file plots/predicted_vs_actual_frozen_layers.png'
+
 lint:
 	@echo "Auto-linting files and performing final style checks..."
 	@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
@@ -25,4 +41,4 @@ clean:
 	@rm -rf output/
 	@rm -f training.csv predictions.csv
 
-.PHONY: data train eval lint clean all
+.PHONY: data train eval train_frozen_encoder eval_frozen_encoder train_frozen_layers eval_frozen_layers lint clean all
diff --git a/prepare_training_data.py b/prepare_training_data.py
index ad19c33..6cbea77 100644
--- a/prepare_training_data.py
+++ b/prepare_training_data.py
@@ -36,7 +36,7 @@ def parse_args():
     parser.add_argument(
         "--samples-per-intersection",
         type=int,
-        default=50,
+        default=100,
         help="Bootstrap samples to create for each intersection.",
     )
     parser.add_argument(
diff --git a/train.py b/train.py
index 27d730c..e56689f 100644
--- a/train.py
+++ b/train.py
@@ -32,6 +32,18 @@ class SignCoordinateDataset(Dataset):
         return self.texts[index], self.coordinates[index]
 
 
+class EmbeddingCoordinateDataset(Dataset):
+    def __init__(self, embeddings, coordinates):
+        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
+        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
+
+    def __len__(self):
+        return len(self.embeddings)
+
+    def __getitem__(self, index):
+        return self.embeddings[index], self.coordinates[index]
+
+
 class CoordinateRegressor(nn.Module):
     def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
         super().__init__()
@@ -66,6 +78,25 @@ def parse_args():
     parser.add_argument("--test-size", type=float, default=0.2)
     parser.add_argument("--hidden-dim", type=int, default=256)
     parser.add_argument("--dropout", type=float, default=0.1)
+    parser.add_argument(
+        "--freeze-encoder",
+        action="store_true",
+        help="Train only the coordinate head; keep the sentence encoder fixed.",
+    )
+    parser.add_argument(
+        "--freeze-transformer-layers",
+        type=int,
+        default=0,
+        help="Freeze the first N transformer layers in the sentence encoder.",
+    )
+    parser.add_argument(
+        "--freeze-attention",
+        action="store_true",
+        help=(
+            "Freeze self-attention parameters while leaving other encoder "
+            "params trainable."
+        ),
+    )
     return parser.parse_args()
 
 
@@ -102,14 +133,71 @@ def collate_fn(model, device):
     return collate
 
 
-def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
-    encoder.train()
+@torch.no_grad()
+def encode_texts(encoder, texts, batch_size, device):
+    encoder.eval()
+    embeddings = []
+    for start in range(0, len(texts), batch_size):
+        batch = texts[start : start + batch_size]
+        features = encoder.tokenize(batch)
+        features = {key: value.to(device) for key, value in features.items()}
+        batch_embeddings = encoder(features)["sentence_embedding"]
+        embeddings.append(batch_embeddings.cpu().numpy())
+    return np.vstack(embeddings)
+
+
+def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
+    head.train()
+    total_loss = 0.0
+
+    for embeddings, labels in dataloader:
+        embeddings = embeddings.to(device)
+        labels = labels.to(device)
+        optimizer.zero_grad()
+        predictions = head(embeddings)
+        loss = loss_fn(predictions, labels)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item() * labels.size(0)
+
+    return total_loss / len(dataloader.dataset)
+
+
+@torch.no_grad()
+def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
+    head.eval()
+    total_loss = 0.0
+    errors_km = []
+
+    for embeddings, labels in dataloader:
+        embeddings = embeddings.to(device)
+        labels = labels.to(device)
+        predictions = head(embeddings)
+        loss = loss_fn(predictions, labels)
+        total_loss += loss.item() * labels.size(0)
+
+        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
+        true_coords = labels.cpu().numpy() * coord_std + coord_mean
+        errors_km.extend(haversine_km(pred_coords, true_coords))
+
+    return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
+
+
+def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
+    if encoder_trainable:
+        encoder.train()
+    else:
+        encoder.eval()
     head.train()
     total_loss = 0.0
 
     for features, labels in dataloader:
         optimizer.zero_grad()
-        embeddings = encoder(features)["sentence_embedding"]
+        if encoder_trainable:
+            embeddings = encoder(features)["sentence_embedding"]
+        else:
+            with torch.no_grad():
+                embeddings = encoder(features)["sentence_embedding"]
         predictions = head(embeddings)
         loss = loss_fn(predictions, labels)
         loss.backward()
@@ -183,6 +271,43 @@ def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
         json.dump(metadata, f, indent=2)
 
 
+def freeze_encoder_parts(encoder, args):
+    if args.freeze_encoder:
+        for parameter in encoder.parameters():
+            parameter.requires_grad = False
+        return
+
+    transformer = encoder[0].auto_model
+    if args.freeze_transformer_layers > 0:
+        layers = transformer.encoder.layer[: args.freeze_transformer_layers]
+        for layer in layers:
+            for parameter in layer.parameters():
+                parameter.requires_grad = False
+
+    if args.freeze_attention:
+        for name, parameter in transformer.named_parameters():
+            if ".attention." in name or name.startswith("attention."):
+                parameter.requires_grad = False
+
+
+def count_trainable_parameters(module):
+    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
+    total = sum(p.numel() for p in module.parameters())
+    return trainable, total
+
+
+def make_optimizer(encoder, head, args):
+    parameter_groups = []
+    encoder_parameters = [p for p in encoder.parameters() if p.requires_grad]
+    if encoder_parameters:
+        group = {"params": encoder_parameters, "lr": args.learning_rate}
+        parameter_groups.append(group)
+    parameter_groups.append(
+        {"params": head.parameters(), "lr": args.head_learning_rate}
+    )
+    return torch.optim.AdamW(parameter_groups, weight_decay=args.weight_decay)
+
+
 def main():
     args = parse_args()
     set_seed(args.seed)
@@ -223,27 +348,48 @@ def main():
         dropout=args.dropout,
     ).to(device)
     save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
-
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        collate_fn=collate_fn(encoder, device),
-    )
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=args.batch_size,
-        shuffle=False,
-        collate_fn=collate_fn(encoder, device),
+    freeze_encoder_parts(encoder, args)
+    encoder_trainable, encoder_total = count_trainable_parameters(encoder)
+    head_trainable, head_total = count_trainable_parameters(head)
+    print(
+        f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
+        f"head params: {head_trainable:,}/{head_total:,}"
     )
 
-    optimizer = torch.optim.AdamW(
-        [
-            {"params": encoder.parameters(), "lr": args.learning_rate},
-            {"params": head.parameters(), "lr": args.head_learning_rate},
-        ],
-        weight_decay=args.weight_decay,
-    )
+    if encoder_trainable == 0:
+        print("Caching frozen encoder embeddings...")
+        all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
+        train_dataset = EmbeddingCoordinateDataset(
+            all_embeddings[train_indices], normalized_coordinates[train_indices]
+        )
+        val_dataset = EmbeddingCoordinateDataset(
+            all_embeddings[val_indices], normalized_coordinates[val_indices]
+        )
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+        )
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+        )
+    else:
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            collate_fn=collate_fn(encoder, device),
+        )
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            collate_fn=collate_fn(encoder, device),
+        )
+
+    optimizer = make_optimizer(encoder, head, args)
     loss_fn = nn.MSELoss()
 
     best_val_loss = float("inf")
@@ -252,10 +398,25 @@ def main():
         f"validating on {len(val_dataset):,} rows"
     )
     for epoch in range(1, args.epochs + 1):
-        train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
-        val_loss, val_error_km = evaluate(
-            encoder, head, val_loader, loss_fn, coord_mean, coord_std
-        )
+        if encoder_trainable == 0:
+            train_loss = train_head_epoch(
+                head, train_loader, optimizer, loss_fn, device
+            )
+            val_loss, val_error_km = evaluate_head(
+                head, val_loader, loss_fn, coord_mean, coord_std, device
+            )
+        else:
+            train_loss = train_epoch(
+                encoder,
+                head,
+                train_loader,
+                optimizer,
+                loss_fn,
+                True,
+            )
+            val_loss, val_error_km = evaluate(
+                encoder, head, val_loader, loss_fn, coord_mean, coord_std
+            )
         print(
             f"epoch={epoch} train_loss={train_loss:.6f} "
             f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"