more training data + frozen layers options
This commit is contained in:
parent
bfad4547ce
commit
b5810dd282
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,4 +3,6 @@ plots*
|
|||||||
*.csv
|
*.csv
|
||||||
output/
|
output/
|
||||||
.requirements_installed
|
.requirements_installed
|
||||||
|
.DS_Store
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
output*
|
||||||
|
|||||||
18
Makefile
18
Makefile
@ -14,6 +14,22 @@ eval: eval.py training.csv
|
|||||||
@echo "Evaluating coordinate regressor..."
|
@echo "Evaluating coordinate regressor..."
|
||||||
@bash -c 'source .venv/bin/activate && python eval.py'
|
@bash -c 'source .venv/bin/activate && python eval.py'
|
||||||
|
|
||||||
|
train_frozen_encoder: train.py training.csv
|
||||||
|
@echo "Training coordinate head with frozen encoder..."
|
||||||
|
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_encoder --freeze-encoder --epochs 50'
|
||||||
|
|
||||||
|
eval_frozen_encoder: eval.py training.csv
|
||||||
|
@echo "Evaluating frozen-encoder coordinate regressor..."
|
||||||
|
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_encoder --output-file predictions_frozen_encoder.csv --plot-file plots/prediction_map_frozen_encoder.png --scatter-plot-file plots/predicted_vs_actual_frozen_encoder.png'
|
||||||
|
|
||||||
|
train_frozen_layers: train.py training.csv
|
||||||
|
@echo "Training coordinate regressor with first transformer layers frozen..."
|
||||||
|
@bash -c 'source .venv/bin/activate && python train.py --output-path output_frozen_layers --freeze-transformer-layers 4'
|
||||||
|
|
||||||
|
eval_frozen_layers: eval.py training.csv
|
||||||
|
@echo "Evaluating frozen-layer coordinate regressor..."
|
||||||
|
@bash -c 'source .venv/bin/activate && python eval.py --model-path output_frozen_layers --output-file predictions_frozen_layers.csv --plot-file plots/prediction_map_frozen_layers.png --scatter-plot-file plots/predicted_vs_actual_frozen_layers.png'
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
@echo "Auto-linting files and performing final style checks..."
|
@echo "Auto-linting files and performing final style checks..."
|
||||||
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
|
@bash -c 'source .venv/bin/activate && isort --profile=black *.py'
|
||||||
@ -25,4 +41,4 @@ clean:
|
|||||||
@rm -rf output/
|
@rm -rf output/
|
||||||
@rm -f training.csv predictions.csv
|
@rm -f training.csv predictions.csv
|
||||||
|
|
||||||
.PHONY: data train eval lint clean all
|
.PHONY: data train eval train_frozen_encoder eval_frozen_encoder train_frozen_layers eval_frozen_layers lint clean all
|
||||||
|
|||||||
@ -36,7 +36,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--samples-per-intersection",
|
"--samples-per-intersection",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=100,
|
||||||
help="Bootstrap samples to create for each intersection.",
|
help="Bootstrap samples to create for each intersection.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
179
train.py
179
train.py
@ -32,6 +32,18 @@ class SignCoordinateDataset(Dataset):
|
|||||||
return self.texts[index], self.coordinates[index]
|
return self.texts[index], self.coordinates[index]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCoordinateDataset(Dataset):
    """Dataset pairing precomputed embeddings with coordinate targets.

    Both inputs are converted to ``float32`` tensors once, at construction
    time, so ``__getitem__`` is a plain index operation.

    Args:
        embeddings: Array-like or tensor with one embedding per sample along
            the first axis.
        coordinates: Array-like or tensor with one target per sample along
            the first axis.

    Raises:
        ValueError: If ``embeddings`` and ``coordinates`` do not hold the
            same number of samples.
    """

    def __init__(self, embeddings, coordinates):
        self.embeddings = self._as_float_tensor(embeddings)
        self.coordinates = self._as_float_tensor(coordinates)
        # A length mismatch would otherwise surface later as an IndexError
        # (or silently mis-pair samples), far from the actual mistake.
        if len(self.embeddings) != len(self.coordinates):
            raise ValueError(
                "embeddings and coordinates must have the same number of "
                f"samples, got {len(self.embeddings)} and "
                f"{len(self.coordinates)}"
            )

    @staticmethod
    def _as_float_tensor(values):
        """Return an independent ``float32`` tensor copy of ``values``."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(existing_tensor) emits a UserWarning; copy with
            # detach().clone() instead, which keeps the same copy semantics.
            return values.detach().clone().to(torch.float32)
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        return self.embeddings[index], self.coordinates[index]
|
||||||
|
|
||||||
|
|
||||||
class CoordinateRegressor(nn.Module):
|
class CoordinateRegressor(nn.Module):
|
||||||
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
|
def __init__(self, embedding_dim, hidden_dim=256, dropout=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -66,6 +78,25 @@ def parse_args():
|
|||||||
parser.add_argument("--test-size", type=float, default=0.2)
|
parser.add_argument("--test-size", type=float, default=0.2)
|
||||||
parser.add_argument("--hidden-dim", type=int, default=256)
|
parser.add_argument("--hidden-dim", type=int, default=256)
|
||||||
parser.add_argument("--dropout", type=float, default=0.1)
|
parser.add_argument("--dropout", type=float, default=0.1)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze-encoder",
|
||||||
|
action="store_true",
|
||||||
|
help="Train only the coordinate head; keep the sentence encoder fixed.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze-transformer-layers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Freeze the first N transformer layers in the sentence encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze-attention",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Freeze self-attention parameters while leaving other encoder "
|
||||||
|
"params trainable."
|
||||||
|
),
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -102,13 +133,70 @@ def collate_fn(model, device):
|
|||||||
return collate
|
return collate
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(encoder, head, dataloader, optimizer, loss_fn):
|
@torch.no_grad()
def encode_texts(encoder, texts, batch_size, device):
    """Embed ``texts`` in batches and return the embeddings as a NumPy array.

    The encoder is switched to eval mode (and left there) and the whole call
    runs without gradient tracking.

    Args:
        encoder: Callable model exposing ``eval()`` and ``tokenize(batch)``;
            calling it on the tokenized features must return a mapping with
            a ``"sentence_embedding"`` tensor.
        texts: Sliceable sequence of inputs to embed.
        batch_size: Number of texts encoded per forward pass; must be > 0.
        device: Device the tokenized features are moved to before encoding.

    Returns:
        The per-batch embeddings stacked row-wise in input order. When
        ``texts`` is empty, an empty ``(0, 0)`` ``float32`` array.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    encoder.eval()
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        features = {key: value.to(device) for key, value in features.items()}
        batch_embeddings = encoder(features)["sentence_embedding"]
        chunks.append(batch_embeddings.cpu().numpy())

    if not chunks:
        # np.vstack([]) raises; the embedding width is unknown without a
        # forward pass, so return a zero-row placeholder instead.
        return np.empty((0, 0), dtype=np.float32)
    return np.vstack(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass of ``head`` over precomputed embeddings.

    Args:
        head: Module mapping an embedding batch to predictions.
        dataloader: Iterable yielding ``(embeddings, labels)`` tensor batches.
        optimizer: Optimizer that updates the head's parameters.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor. The per-sample average below assumes it is a batch mean.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The loss averaged over the samples actually processed.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    head.train()
    loss_sum = 0.0
    samples_seen = 0

    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(head(embeddings), labels)
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        # Weight the batch loss by its size so a short final batch does not
        # skew the epoch average.
        loss_sum += loss.item() * batch_count
        samples_seen += batch_count

    if samples_seen == 0:
        raise ValueError("dataloader yielded no samples to train on")
    # Divide by what was iterated, not len(dataloader.dataset): the two
    # differ when the loader drops or subsamples batches.
    return loss_sum / samples_seen
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate ``head`` on precomputed embeddings without tracking gradients.

    Args:
        head: Module mapping an embedding batch to predictions.
        dataloader: Iterable yielding ``(embeddings, labels)`` tensor batches.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor. The per-sample average below assumes it is a batch mean.
        coord_mean: Offset added back to predictions and labels.
        coord_std: Scale multiplied back into predictions and labels.
            NOTE(review): presumably the statistics used to normalise the
            targets; confirm against the caller.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_error_km)``: the loss averaged over the samples
        actually processed, and the mean of the values returned by
        ``haversine_km`` for the de-normalised coordinates.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    head.eval()
    loss_sum = 0.0
    samples_seen = 0
    errors_km = []

    for embeddings, labels in dataloader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)
        predictions = head(embeddings)

        batch_count = labels.size(0)
        loss_sum += loss_fn(predictions, labels).item() * batch_count
        samples_seen += batch_count

        # Map both sides back to coordinate space before measuring distance.
        pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
        true_coords = labels.cpu().numpy() * coord_std + coord_mean
        errors_km.extend(haversine_km(pred_coords, true_coords))

    if samples_seen == 0:
        raise ValueError("dataloader yielded no samples to evaluate")
    # Divide by what was iterated, not len(dataloader.dataset): the two
    # differ when the loader drops or subsamples batches.
    return loss_sum / samples_seen, float(np.mean(errors_km))
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
|
||||||
|
if encoder_trainable:
|
||||||
encoder.train()
|
encoder.train()
|
||||||
|
else:
|
||||||
|
encoder.eval()
|
||||||
head.train()
|
head.train()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
|
||||||
for features, labels in dataloader:
|
for features, labels in dataloader:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
if encoder_trainable:
|
||||||
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
embeddings = encoder(features)["sentence_embedding"]
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
predictions = head(embeddings)
|
predictions = head(embeddings)
|
||||||
loss = loss_fn(predictions, labels)
|
loss = loss_fn(predictions, labels)
|
||||||
@ -183,6 +271,43 @@ def save_initial_state(output_path, encoder, head, coord_mean, coord_std, args):
|
|||||||
json.dump(metadata, f, indent=2)
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_encoder_parts(encoder, args):
    """Disable gradients for the encoder parameters selected by ``args``.

    Args:
        encoder: Module to freeze. For partial freezing its first element
            must expose ``auto_model`` with an ``encoder.layer`` sequence.
        args: Namespace providing ``freeze_encoder`` (bool),
            ``freeze_transformer_layers`` (int) and ``freeze_attention``
            (bool).

    ``freeze_encoder`` takes precedence: every encoder parameter is frozen
    and the other two options are ignored. Otherwise the first
    ``freeze_transformer_layers`` layers are frozen (values <= 0 freeze
    none), and ``freeze_attention`` additionally freezes every transformer
    parameter whose name contains an ``attention`` component.
    """
    if args.freeze_encoder:
        for parameter in encoder.parameters():
            parameter.requires_grad = False
        return

    layer_count = args.freeze_transformer_layers
    if layer_count <= 0 and not args.freeze_attention:
        # Nothing to freeze: leave the encoder internals untouched so that a
        # default run does not depend on encoder[0] exposing `auto_model`.
        return

    transformer = encoder[0].auto_model
    if layer_count > 0:
        for layer in transformer.encoder.layer[:layer_count]:
            for parameter in layer.parameters():
                parameter.requires_grad = False

    if args.freeze_attention:
        for name, parameter in transformer.named_parameters():
            if ".attention." in name or name.startswith("attention."):
                parameter.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
|
def count_trainable_parameters(module):
    """Return ``(trainable, total)`` parameter element counts for ``module``.

    ``trainable`` counts only parameters with ``requires_grad`` set;
    ``total`` counts every parameter yielded by ``module.parameters()``.
    """
    trainable_count = 0
    total_count = 0
    for parameter in module.parameters():
        size = parameter.numel()
        total_count += size
        if parameter.requires_grad:
            trainable_count += size
    return trainable_count, total_count
|
||||||
|
|
||||||
|
|
||||||
|
def make_optimizer(encoder, head, args):
    """Build an AdamW optimizer with separate encoder and head groups.

    Args:
        encoder: Module whose parameters with ``requires_grad`` set are
            optimised at ``args.learning_rate``. The group is omitted
            entirely when the encoder has no such parameters.
        head: Module whose parameters are optimised at
            ``args.head_learning_rate``.
        args: Namespace providing ``learning_rate``, ``head_learning_rate``
            and ``weight_decay``.

    Returns:
        A ``torch.optim.AdamW`` instance; the encoder group, when present,
        precedes the head group.
    """
    trainable_encoder_params = list(
        filter(lambda parameter: parameter.requires_grad, encoder.parameters())
    )
    head_group = {"params": head.parameters(), "lr": args.head_learning_rate}

    if trainable_encoder_params:
        encoder_group = {
            "params": trainable_encoder_params,
            "lr": args.learning_rate,
        }
        groups = [encoder_group, head_group]
    else:
        groups = [head_group]

    return torch.optim.AdamW(groups, weight_decay=args.weight_decay)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
@ -223,7 +348,34 @@ def main():
|
|||||||
dropout=args.dropout,
|
dropout=args.dropout,
|
||||||
).to(device)
|
).to(device)
|
||||||
save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
|
save_initial_state(args.output_path, encoder, head, coord_mean, coord_std, args)
|
||||||
|
freeze_encoder_parts(encoder, args)
|
||||||
|
encoder_trainable, encoder_total = count_trainable_parameters(encoder)
|
||||||
|
head_trainable, head_total = count_trainable_parameters(head)
|
||||||
|
print(
|
||||||
|
f"Trainable encoder params: {encoder_trainable:,}/{encoder_total:,}; "
|
||||||
|
f"head params: {head_trainable:,}/{head_total:,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_trainable == 0:
|
||||||
|
print("Caching frozen encoder embeddings...")
|
||||||
|
all_embeddings = encode_texts(encoder, texts, args.batch_size, device)
|
||||||
|
train_dataset = EmbeddingCoordinateDataset(
|
||||||
|
all_embeddings[train_indices], normalized_coordinates[train_indices]
|
||||||
|
)
|
||||||
|
val_dataset = EmbeddingCoordinateDataset(
|
||||||
|
all_embeddings[val_indices], normalized_coordinates[val_indices]
|
||||||
|
)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -237,13 +389,7 @@ def main():
|
|||||||
collate_fn=collate_fn(encoder, device),
|
collate_fn=collate_fn(encoder, device),
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = make_optimizer(encoder, head, args)
|
||||||
[
|
|
||||||
{"params": encoder.parameters(), "lr": args.learning_rate},
|
|
||||||
{"params": head.parameters(), "lr": args.head_learning_rate},
|
|
||||||
],
|
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
)
|
|
||||||
loss_fn = nn.MSELoss()
|
loss_fn = nn.MSELoss()
|
||||||
best_val_loss = float("inf")
|
best_val_loss = float("inf")
|
||||||
|
|
||||||
@ -252,7 +398,22 @@ def main():
|
|||||||
f"validating on {len(val_dataset):,} rows"
|
f"validating on {len(val_dataset):,} rows"
|
||||||
)
|
)
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, args.epochs + 1):
|
||||||
train_loss = train_epoch(encoder, head, train_loader, optimizer, loss_fn)
|
if encoder_trainable == 0:
|
||||||
|
train_loss = train_head_epoch(
|
||||||
|
head, train_loader, optimizer, loss_fn, device
|
||||||
|
)
|
||||||
|
val_loss, val_error_km = evaluate_head(
|
||||||
|
head, val_loader, loss_fn, coord_mean, coord_std, device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_loss = train_epoch(
|
||||||
|
encoder,
|
||||||
|
head,
|
||||||
|
train_loader,
|
||||||
|
optimizer,
|
||||||
|
loss_fn,
|
||||||
|
True,
|
||||||
|
)
|
||||||
val_loss, val_error_km = evaluate(
|
val_loss, val_error_km = evaluate(
|
||||||
encoder, head, val_loader, loss_fn, coord_mean, coord_std
|
encoder, head, val_loader, loss_fn, coord_mean, coord_std
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user