more training data + frozen layers options

This commit is contained in:
Michael Pilosov 2026-05-25 15:16:19 -06:00
parent bfad4547ce
commit b5810dd282
4 changed files with 207 additions and 28 deletions

2
.gitignore vendored
View File

@ -3,4 +3,6 @@ plots*
*.csv
output/
.requirements_installed
.DS_Store
__pycache__/
output*

View File

@ -14,6 +14,22 @@ eval: eval.py training.csv
@echo "Evaluating coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py'
train_frozen_encoder: train.py training.csv
@echo "Training coordinate head with frozen encoder..."
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_encoder --freeze-encoder --epochs 50'
eval_frozen_encoder: eval.py training.csv
@echo "Evaluating frozen-encoder coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_encoder --output-file predictions_frozen_encoder.csv --plot-file plots/prediction_map_frozen_encoder.png --scatter-plot-file plots/predicted_vs_actual_frozen_encoder.png'
train_frozen_layers: train.py training.csv
@echo "Training coordinate regressor with first transformer layers frozen..."
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_layers --freeze-transformer-layers 4'
eval_frozen_layers: eval.py training.csv
@echo "Evaluating frozen-layer coordinate regressor..."
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_layers --output-file predictions_frozen_layers.csv --plot-file plots/prediction_map_frozen_layers.png --scatter-plot-file plots/predicted_vs_actual_frozen_layers.png'
lint:
@echo "Auto-linting files and performing final style checks..."
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
@ -25,4 +41,4 @@ clean:
@rm -rf output/
@rm -f training.csv predictions.csv
.PHONY: data train eval lint clean all
.PHONY: data train eval train_frozen_encoder eval_frozen_encoder train_frozen_layers eval_frozen_layers lint clean all

View File

@ -36,7 +36,7 @@ def parse_args():
parser.add_argument(
"--samples-per-intersection",
type=int,
default=50,
default=100,
help="Bootstrap samples to create for each intersection.",
)
parser.add_argument(

213
train.py
View File

@ -32,6 +32,18 @@ class SignCoordinateDataset(Dataset):
return self.texts[index], self.coordinates[index]
class EmbeddingCoordinateDataset(Dataset):
def __init__(self, embeddings, coordinates):
self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
def __len__(self):
return len(self.embeddings)
def __getitem__(self, index):
return self.embeddings[index], self.coordinates[index]
class CoordinateRegressor(nn.Module):
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
super().__init__()
@ -66,6 +78,25 @@ def parse_args():
parser.add_argument("--test-size", type=float, default=0.2)
parser.add_argument("--hidden-dim", type=int, default=256)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument(
"--freeze-encoder",
action="store_true",
help="Train only the coordinate head; keep the sentence encoder fixed.",
)
parser.add_argument(
"--freeze-transformer-layers",
type=int,
default=0,
help="Freeze the first N transformer layers in the sentence encoder.",
)
parser.add_argument(
"--freeze-attention",
action="store_true",
help=(
"Freeze self-attention parameters while leaving other encoder "
"params trainable."
),
)
return parser.parse_args()
@ -102,14 +133,71 @@ def collate_fn(model, device):
return collate
def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
encoder.train()
@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
encoder.eval()
embeddings = []
for start in range(0, len(texts), batch_size):
batch = texts[start : start + batch_size]
features = encoder.tokenize(batch)
features = {key: value.to(device) for key, value in features.items()}
batch_embeddings = encoder(features)["sentence_embedding"]
embeddings.append(batch_embeddings.cpu().numpy())
return np.vstack(embeddings)
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
head.train()
total_loss = 0.0
for embeddings, labels in dataloader:
embeddings = embeddings.to(device)
labels = labels.to(device)
optimizer.zero_grad()
predictions = head(embeddings)
loss = loss_fn(predictions, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * labels.size(0)
return total_loss / len(dataloader.dataset)
@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
head.eval()
total_loss = 0.0
errors_km = []
for embeddings, labels in dataloader:
embeddings = embeddings.to(device)
labels = labels.to(device)
predictions = head(embeddings)
loss = loss_fn(predictions, labels)
total_loss += loss.item() * labels.size(0)
pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
true_coords = labels.cpu().numpy() * coord_std + coord_mean
errors_km.extend(haversine_km(pred_coords, true_coords))
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
if encoder_trainable:
encoder.train()
else:
encoder.eval()
head.train()
total_loss = 0.0
for features, labels in dataloader:
optimizer.zero_grad()
embeddings = encoder(features)["sentence_embedding"]
if encoder_trainable:
embeddings = encoder(features)["sentence_embedding"]
else:
with torch.no_grad():
embeddings = encoder(features)["sentence_embedding"]
predictions = head(embeddings)
loss = loss_fn(predictions, labels)
loss.backward()
@ -183,6 +271,43 @@ def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
json.dump(metadata, f, indent=2)
def freeze_encoder_parts(encoder, args):
if args.freeze_encoder:
for parameter in encoder.parameters():
parameter.requires_grad = False
return
transformer = encoder[0].auto_model
if args.freeze_transformer_layers > 0:
layers = transformer.encoder.layer[: args.freeze_transformer_layers]
for layer in layers:
for parameter in layer.parameters():
parameter.requires_grad = False
if args.freeze_attention:
for name, parameter in transformer.named_parameters():
if ".attention." in name or name.startswith("attention."):
parameter.requires_grad = False
def count_trainable_parameters(module):
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
total = sum(p.numel() for p in module.parameters())
return trainable, total
def make_optimizer(encoder, head, args):
parameter_groups = []
encoder_parameters = [p for p in encoder.parameters() if p.requires_grad]
if encoder_parameters:
group = {"params": encoder_parameters, "lr": args.learning_rate}
parameter_groups.append(group)
parameter_groups.append(
{"params": head.parameters(), "lr": args.head_learning_rate}
)
return torch.optim.AdamW(parameter_groups, weight_decay=args.weight_decay)
def main():
args = parse_args()
set_seed(args.seed)
@ -223,27 +348,48 @@ def main():
dropout=args.dropout,
).to(device)
save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=collate_fn(encoder, device),
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=collate_fn(encoder, device),
freeze_encoder_parts(encoder, args)
encoder_trainable, encoder_total = count_trainable_parameters(encoder)
head_trainable, head_total = count_trainable_parameters(head)
print(
f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
f"head params: {head_trainable:,}/{head_total:,}"
)
optimizer = torch.optim.AdamW(
[
{"params": encoder.parameters(), "lr": args.learning_rate},
{"params": head.parameters(), "lr": args.head_learning_rate},
],
weight_decay=args.weight_decay,
)
if encoder_trainable == 0:
print("Caching frozen encoder embeddings...")
all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
train_dataset = EmbeddingCoordinateDataset(
all_embeddings[train_indices], normalized_coordinates[train_indices]
)
val_dataset = EmbeddingCoordinateDataset(
all_embeddings[val_indices], normalized_coordinates[val_indices]
)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
)
else:
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=collate_fn(encoder, device),
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=collate_fn(encoder, device),
)
optimizer = make_optimizer(encoder, head, args)
loss_fn = nn.MSELoss()
best_val_loss = float("inf")
@ -252,10 +398,25 @@ def main():
f"validating on {len(val_dataset):,} rows"
)
for epoch in range(1, args.epochs + 1):
train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
val_loss, val_error_km = evaluate(
encoder, head, val_loader, loss_fn, coord_mean, coord_std
)
if encoder_trainable == 0:
train_loss = train_head_epoch(
head, train_loader, optimizer, loss_fn, device
)
val_loss, val_error_km = evaluate_head(
head, val_loader, loss_fn, coord_mean, coord_std, device
)
else:
train_loss = train_epoch(
encoder,
head,
train_loader,
optimizer,
loss_fn,
True,
)
val_loss, val_error_km = evaluate(
encoder, head, val_loader, loss_fn, coord_mean, coord_std
)
print(
f"epoch={epoch} train_loss={train_loss:.6f} "
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"