more training data + frozen layers options
This commit is contained in:
parent
bfad4547ce
commit
b5810dd282
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,4 +3,6 @@ plots*
|
||||
*.csv
|
||||
output/
|
||||
.requirements_installed
|
||||
.DS_Store
|
||||
__pycache__/
|
||||
output*
|
||||
|
||||
18
Makefile
18
Makefile
@ -14,6 +14,22 @@ eval: eval.py training.csv
|
||||
@echo "Evaluating coordinate regressor..."
|
||||
@bash -c 'source .venv/bin/activate && python eval.py'
|
||||
|
||||
train_frozen_encoder: train.py training.csv
|
||||
@echo "Training coordinate head with frozen encoder..."
|
||||
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_encoder --freeze-encoder --epochs 50'
|
||||
|
||||
eval_frozen_encoder: eval.py training.csv
|
||||
@echo "Evaluating frozen-encoder coordinate regressor..."
|
||||
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_encoder --output-file predictions_frozen_encoder.csv --plot-file plots/prediction_map_frozen_encoder.png --scatter-plot-file plots/predicted_vs_actual_frozen_encoder.png'
|
||||
|
||||
train_frozen_layers: train.py training.csv
|
||||
@echo "Training coordinate regressor with first transformer layers frozen..."
|
||||
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_layers --freeze-transformer-layers 4'
|
||||
|
||||
eval_frozen_layers: eval.py training.csv
|
||||
@echo "Evaluating frozen-layer coordinate regressor..."
|
||||
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_layers --output-file predictions_frozen_layers.csv --plot-file plots/prediction_map_frozen_layers.png --scatter-plot-file plots/predicted_vs_actual_frozen_layers.png'
|
||||
|
||||
lint:
|
||||
@echo "Auto-linting files and performing final style checks..."
|
||||
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
|
||||
@ -25,4 +41,4 @@ clean:
|
||||
@rm -rf output/
|
||||
@rm -f training.csv predictions.csv
|
||||
|
||||
.PHONY: data train eval lint clean all
|
||||
.PHONY: data train eval train_frozen_encoder eval_frozen_encoder train_frozen_layers eval_frozen_layers lint clean all
|
||||
|
||||
@ -36,7 +36,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--samples-per-intersection",
|
||||
type=int,
|
||||
default=50,
|
||||
default=100,
|
||||
help="Bootstrap samples to create for each intersection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
179
train.py
179
train.py
@ -32,6 +32,18 @@ class SignCoordinateDataset(Dataset):
|
||||
return self.texts[index], self.coordinates[index]
|
||||
|
||||
|
||||
class EmbeddingCoordinateDataset(Dataset):
    """Dataset pairing precomputed embeddings with coordinate targets.

    Both inputs are copied into ``float32`` tensors at construction time, and
    indexing returns an ``(embedding, coordinates)`` pair of tensors.
    """

    def __init__(self, embeddings, coordinates):
        """Store the embeddings and their matching coordinate rows.

        Args:
            embeddings: Array-like of embeddings, one row per sample.
            coordinates: Array-like of targets, one row per sample.

        Raises:
            ValueError: If the two inputs hold a different number of rows.
        """
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.coordinates = torch.tensor(coordinates, dtype=torch.float32)
        # __len__ only looks at the embeddings, so a length mismatch would
        # otherwise show up late (IndexError mid-epoch) or not at all
        # (extra coordinate rows silently unused).
        if len(self.embeddings) != len(self.coordinates):
            raise ValueError(
                "embeddings and coordinates must have the same number of rows "
                f"(got {len(self.embeddings)} and {len(self.coordinates)})"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.embeddings)

    def __getitem__(self, index):
        """Return the ``(embedding, coordinates)`` tensors at ``index``."""
        return self.embeddings[index], self.coordinates[index]
|
||||
|
||||
|
||||
class CoordinateRegressor(nn.Module):
|
||||
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
|
||||
super().__init__()
|
||||
@ -66,6 +78,25 @@ def parse_args():
|
||||
parser.add_argument("--test-size", type=float, default=0.2)
|
||||
parser.add_argument("--hidden-dim", type=int, default=256)
|
||||
parser.add_argument("--dropout", type=float, default=0.1)
|
||||
parser.add_argument(
|
||||
"--freeze-encoder",
|
||||
action="store_true",
|
||||
help="Train only the coordinate head; keep the sentence encoder fixed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze-transformer-layers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Freeze the first N transformer layers in the sentence encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze-attention",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Freeze self-attention parameters while leaving other encoder "
|
||||
"params trainable."
|
||||
),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -102,13 +133,70 @@ def collate_fn(model, device):
|
||||
return collate
|
||||
|
||||
|
||||
def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
|
||||
@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
    """Embed ``texts`` with ``encoder`` in batches, without tracking gradients.

    The encoder is switched to eval mode and left in that mode on return.

    Args:
        encoder: Module exposing ``tokenize(batch)``, which returns a mapping
            of tensors, and whose forward pass returns a mapping containing
            ``"sentence_embedding"``.
        texts: Sliceable sequence of inputs to embed.
        batch_size: Number of texts passed to the encoder per forward pass.
        device: Device the tokenized features are moved to.

    Returns:
        A NumPy array with the per-batch embeddings stacked in input order.
        When ``texts`` is empty, an empty ``(0, 0)`` float32 array is
        returned, because the embedding width cannot be known without
        running the encoder.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A negative step would make range() yield nothing and hide the
        # mistake behind an unrelated np.vstack error.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    encoder.eval()
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        features = {key: value.to(device) for key, value in features.items()}
        batch_embeddings = encoder(features)["sentence_embedding"]
        # Move each batch to host memory right away so only one batch of
        # embeddings is held on the device at a time.
        embeddings.append(batch_embeddings.cpu().numpy())

    if not embeddings:
        # np.vstack([]) raises; return an explicitly empty result instead.
        return np.empty((0, 0), dtype=np.float32)
    return np.vstack(embeddings)
|
||||
|
||||
|
||||
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass of ``head`` over precomputed embeddings.

    Args:
        head: Module mapping a batch of embeddings to predictions.
        dataloader: Iterable yielding ``(embeddings, labels)`` tensor batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor. The per-sample weighting below assumes it averages over
            the batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The mean loss per sample, as a float, over the samples the
        dataloader actually yielded.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    head.train()
    total_loss = 0.0
    samples_seen = 0

    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(head(embeddings), labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        samples_seen += batch_count

    if samples_seen == 0:
        raise ValueError("dataloader yielded no samples to train on")
    # Divide by what was actually seen rather than len(dataloader.dataset):
    # the two differ when the loader drops or subsamples data.
    return total_loss / samples_seen
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate ``head`` on precomputed embeddings without tracking gradients.

    Args:
        head: Module mapping a batch of embeddings to predictions.
        dataloader: Iterable yielding ``(embeddings, labels)`` tensor batches.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor. The per-sample weighting below assumes it averages over
            the batch.
        coord_mean: Offset added back to predictions and labels before the
            distance computation.
        coord_std: Scale multiplied into predictions and labels before the
            distance computation.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A ``(mean_loss, mean_error)`` tuple of floats: the mean loss per
        sample actually yielded, and the mean of the values returned by
        ``haversine_km`` (presumably kilometres; the helper is defined
        elsewhere in this file).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    head.eval()
    total_loss = 0.0
    samples_seen = 0
    errors_km = []

    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)
        predictions = head(embeddings)
        loss = loss_fn(predictions, labels)

        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        samples_seen += batch_count

        # Undo the target scaling so distances are measured on the original
        # coordinates rather than on the scaled training targets.
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))

    if samples_seen == 0:
        # Without this guard the loss division fails and np.mean([]) is NaN.
        raise ValueError("dataloader yielded no samples to evaluate")
    # Divide by what was actually seen rather than len(dataloader.dataset):
    # the two differ when the loader drops or subsamples data.
    return total_loss / samples_seen, float(np.mean(errors_km))
|
||||
|
||||
|
||||
def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
|
||||
if encoder_trainable:
|
||||
encoder.train()
|
||||
else:
|
||||
encoder.eval()
|
||||
head.train()
|
||||
total_loss = 0.0
|
||||
|
||||
for features, labels in dataloader:
|
||||
optimizer.zero_grad()
|
||||
if encoder_trainable:
|
||||
embeddings = encoder(features)["sentence_embedding"]
|
||||
else:
|
||||
with torch.no_grad():
|
||||
embeddings = encoder(features)["sentence_embedding"]
|
||||
predictions = head(embeddings)
|
||||
loss = loss_fn(predictions, labels)
|
||||
@ -183,6 +271,43 @@ def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
|
||||
def freeze_encoder_parts(encoder, args):
    """Disable gradients for the encoder parameters selected by ``args``.

    The options are applied as follows:

    * ``args.freeze_encoder``: every encoder parameter is frozen and the
      other options are ignored.
    * ``args.freeze_transformer_layers`` (N > 0): the parameters of the first
      N entries of ``encoder[0].auto_model.encoder.layer`` are frozen. An N
      larger than the number of layers is expected to freeze all of them,
      assuming ``layer`` slices like a list.
    * ``args.freeze_attention``: every parameter of
      ``encoder[0].auto_model`` whose name contains ``".attention."`` or
      starts with ``"attention."`` is frozen.

    When none of the options is set the encoder is left untouched, and its
    internals are not accessed at all.

    Args:
        encoder: Module to modify in place. The partial-freeze options
            require ``encoder[0].auto_model``.
        args: Namespace providing ``freeze_encoder``,
            ``freeze_transformer_layers`` and ``freeze_attention``.

    Raises:
        ValueError: If ``args.freeze_transformer_layers`` is negative.
    """
    if args.freeze_encoder:
        for parameter in encoder.parameters():
            parameter.requires_grad = False
        return

    if args.freeze_transformer_layers < 0:
        # A negative count used to be ignored silently, hiding CLI typos.
        raise ValueError(
            "freeze_transformer_layers must be non-negative, got "
            f"{args.freeze_transformer_layers}"
        )

    if not args.freeze_transformer_layers and not args.freeze_attention:
        # Nothing to freeze: do not require a specific encoder layout when
        # no partial-freeze option was requested.
        return

    transformer = encoder[0].auto_model
    if args.freeze_transformer_layers > 0:
        layers = transformer.encoder.layer[: args.freeze_transformer_layers]
        for layer in layers:
            for parameter in layer.parameters():
                parameter.requires_grad = False

    if args.freeze_attention:
        for name, parameter in transformer.named_parameters():
            if ".attention." in name or name.startswith("attention."):
                parameter.requires_grad = False
|
||||
|
||||
|
||||
def count_trainable_parameters(module):
    """Count the parameter elements of ``module``.

    Returns:
        A ``(trainable, total)`` tuple: the number of elements in parameters
        with ``requires_grad`` set, and the number of elements in all
        parameters.
    """
    trainable_count = 0
    total_count = 0
    for parameter in module.parameters():
        size = parameter.numel()
        total_count += size
        if parameter.requires_grad:
            trainable_count += size
    return trainable_count, total_count
|
||||
|
||||
|
||||
def make_optimizer(encoder, head, args):
    """Build an AdamW optimizer with separate learning rates per component.

    Encoder parameters that still require gradients form a group using
    ``args.learning_rate``; that group is omitted when there are none. The
    head parameters always form a group using ``args.head_learning_rate``.
    ``args.weight_decay`` is passed to the optimizer as the shared default.
    """
    head_group = {"params": head.parameters(), "lr": args.head_learning_rate}
    live_encoder_params = [
        parameter for parameter in encoder.parameters() if parameter.requires_grad
    ]

    if live_encoder_params:
        groups = [
            {"params": live_encoder_params, "lr": args.learning_rate},
            head_group,
        ]
    else:
        # AdamW rejects empty parameter groups, so drop the encoder group.
        groups = [head_group]

    return torch.optim.AdamW(groups, weight_decay=args.weight_decay)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
@ -223,7 +348,34 @@ def main():
|
||||
dropout=args.dropout,
|
||||
).to(device)
|
||||
save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
|
||||
freeze_encoder_parts(encoder, args)
|
||||
encoder_trainable, encoder_total = count_trainable_parameters(encoder)
|
||||
head_trainable, head_total = count_trainable_parameters(head)
|
||||
print(
|
||||
f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
|
||||
f"head params: {head_trainable:,}/{head_total:,}"
|
||||
)
|
||||
|
||||
if encoder_trainable == 0:
|
||||
print("Caching frozen encoder embeddings...")
|
||||
all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
|
||||
train_dataset = EmbeddingCoordinateDataset(
|
||||
all_embeddings[train_indices], normalized_coordinates[train_indices]
|
||||
)
|
||||
val_dataset = EmbeddingCoordinateDataset(
|
||||
all_embeddings[val_indices], normalized_coordinates[val_indices]
|
||||
)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
)
|
||||
else:
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
@ -237,13 +389,7 @@ def main():
|
||||
collate_fn=collate_fn(encoder, device),
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[
|
||||
{"params": encoder.parameters(), "lr": args.learning_rate},
|
||||
{"params": head.parameters(), "lr": args.head_learning_rate},
|
||||
],
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
optimizer = make_optimizer(encoder, head, args)
|
||||
loss_fn = nn.MSELoss()
|
||||
best_val_loss = float("inf")
|
||||
|
||||
@ -252,7 +398,22 @@ def main():
|
||||
f"validating on {len(val_dataset):,} rows"
|
||||
)
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
|
||||
if encoder_trainable == 0:
|
||||
train_loss = train_head_epoch(
|
||||
head, train_loader, optimizer, loss_fn, device
|
||||
)
|
||||
val_loss, val_error_km = evaluate_head(
|
||||
head, val_loader, loss_fn, coord_mean, coord_std, device
|
||||
)
|
||||
else:
|
||||
train_loss = train_epoch(
|
||||
encoder,
|
||||
head,
|
||||
train_loader,
|
||||
optimizer,
|
||||
loss_fn,
|
||||
True,
|
||||
)
|
||||
val_loss, val_error_km = evaluate(
|
||||
encoder, head, val_loader, loss_fn, coord_mean, coord_std
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user