optimizations?

This commit is contained in:
Michael Pilosov 2026-05-25 21:45:08 +00:00
parent 44c7753856
commit 8f4d4c1057

223
train.py
View File

@ -75,7 +75,22 @@ def parse_args():
) )
parser.add_argument("--seed", type=int, default=1992) parser.add_argument("--seed", type=int, default=1992)
parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument(
"--num-workers",
type=int,
default=2,
help="DataLoader workers to prefetch batches while the GPU trains.",
)
parser.add_argument(
"--save-every-epochs",
type=int,
default=5,
help=(
"Write the best checkpoint to disk every N epochs (and at the final "
"epoch). Validation still runs every epoch."
),
)
parser.add_argument("--learning-rate", type=float, default=2e-5) parser.add_argument("--learning-rate", type=float, default=2e-5)
parser.add_argument("--head-learning-rate", type=float, default=1e-3) parser.add_argument("--head-learning-rate", type=float, default=1e-3)
parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--weight-decay", type=float, default=0.01)
@ -127,16 +142,83 @@ def normalize_coordinates(coordinates):
return (coordinates - mean) / std, mean, std return (coordinates - mean) / std, mean, std
def move_features_to_device(features, device, non_blocking=False):
    """Return a new dict with every value of ``features`` moved to ``device``.

    Args:
        features: Mapping whose values expose ``.to(device, non_blocking=...)``
            (e.g. tensors produced by a tokenizer).
        device: Target device, forwarded unchanged to each ``.to`` call.
        non_blocking: Forwarded to each ``.to`` call.

    Returns:
        A fresh dict with the same keys, in the same order, holding the
        moved values. The input mapping itself is not modified.
    """
    moved = {}
    for name, tensor in features.items():
        moved[name] = tensor.to(device, non_blocking=non_blocking)
    return moved
def move_batch_to_device(features, labels, device, pin_memory=False):
    """Move a ``(features, labels)`` batch to ``device``.

    When ``pin_memory`` is set and the target is a CUDA device, every feature
    tensor and the labels are pinned first and the copies are issued with
    ``non_blocking`` enabled; otherwise plain blocking copies are used.

    Args:
        features: Mapping of name -> tensor.
        labels: Label tensor for the batch.
        device: Target ``torch.device`` (its ``.type`` is inspected).
        pin_memory: Whether pinned, non-blocking transfers are requested.

    Returns:
        Tuple ``(features_on_device, labels_on_device)``.
    """
    # Pinned staging + async copy only applies to host -> CUDA transfers.
    use_pinned = pin_memory and device.type == "cuda"
    if use_pinned:
        labels = labels.pin_memory()
        features = {name: tensor.pin_memory() for name, tensor in features.items()}
    features_on_device = move_features_to_device(
        features, device, non_blocking=use_pinned
    )
    labels_on_device = labels.to(device, non_blocking=use_pinned)
    return features_on_device, labels_on_device
def move_tensors_to_device(tensors, device, pin_memory=False):
    """Move each tensor in ``tensors`` to ``device`` and return them as a list.

    When ``pin_memory`` is set and the target is a CUDA device, each tensor is
    pinned before the copy and the copy is issued with ``non_blocking``
    enabled; otherwise a plain blocking ``.to(device)`` is used.

    Args:
        tensors: Iterable of tensors.
        device: Target ``torch.device`` (its ``.type`` is inspected).
        pin_memory: Whether pinned, non-blocking transfers are requested.

    Returns:
        A list of tensors on ``device``, in input order.
    """
    # Pinned staging + async copy only applies to host -> CUDA transfers.
    use_pinned = pin_memory and device.type == "cuda"
    moved = []
    for tensor in tensors:
        if use_pinned:
            tensor = tensor.pin_memory()
        moved.append(tensor.to(device, non_blocking=use_pinned))
    return moved
def _collate_text_batch(tokenize, batch):
    """Tokenize the texts of ``batch`` and stack its labels.

    ``batch`` is a sequence of ``(text, label_tensor)`` pairs. No device
    transfer happens here; the result stays wherever ``tokenize`` and the
    label tensors already live.
    """
    texts, labels = zip(*batch)
    return tokenize(list(texts)), torch.stack(labels)


def make_text_collate(tokenize):
    """Build a DataLoader ``collate_fn`` that tokenizes texts and stacks labels.

    The returned callable is a ``functools.partial`` over a module-level
    function rather than a nested closure. A closure cannot be pickled, which
    breaks ``DataLoader(num_workers > 0)`` on platforms whose workers are
    started with ``spawn``; the partial is picklable whenever ``tokenize``
    itself is.

    Args:
        tokenize: Callable taking a list of strings and returning the model
            features for that list.

    Returns:
        A callable ``collate(batch) -> (features, stacked_labels)`` where
        ``batch`` is a sequence of ``(text, label_tensor)`` pairs.
    """
    # Local import keeps this block self-contained (no new file-level import).
    from functools import partial

    return partial(_collate_text_batch, tokenize)
def embedding_collate(batch):
    """Collate ``(embedding, label)`` pairs into two stacked tensors.

    Args:
        batch: Sequence of ``(embedding_tensor, label_tensor)`` pairs.

    Returns:
        Tuple ``(embeddings, labels)``, each built with ``torch.stack`` over
        the corresponding elements of the pairs, in batch order.
    """
    embedding_rows, label_rows = zip(*batch)
    stacked_embeddings = torch.stack(embedding_rows)
    stacked_labels = torch.stack(label_rows)
    return stacked_embeddings, stacked_labels
def make_dataloader(dataset, batch_size, shuffle, collate_fn, num_workers, pin_memory):
    """Construct a ``DataLoader`` with the project's worker settings.

    Args:
        dataset: Dataset to iterate.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        collate_fn: Callable merging a list of samples into one batch.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        The configured ``DataLoader``. ``persistent_workers=True`` is passed
        only when ``num_workers > 0``.
    """
    # persistent_workers is only requested when worker processes exist.
    worker_options = {"persistent_workers": True} if num_workers > 0 else {}
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **worker_options,
    )
def copy_module_state(module):
    """Return an independent CPU snapshot of ``module``'s state dict.

    ``Tensor.cpu()`` returns the *same* tensor when it already lives on the
    CPU, so a plain ``.detach().cpu()`` snapshot of a CPU-resident module
    would alias the live parameters and silently follow every later
    optimizer step. ``.to("cpu", copy=True)`` always allocates new storage,
    so the snapshot stays frozen regardless of where the module lives.

    Args:
        module: Module whose ``state_dict()`` should be captured.

    Returns:
        Dict mapping each state-dict key to a detached CPU tensor that shares
        no storage with the module.
    """
    return {
        key: value.detach().to("cpu", copy=True)
        for key, value in module.state_dict().items()
    }
def save_best_checkpoint(
    output_path, encoder, head, best_states, coord_mean, coord_std, args
):
    """Write the best-so-far weights to disk without disturbing training.

    The best weights are temporarily loaded into ``encoder`` and ``head``,
    ``save_model`` is called, and the weights that were live on entry are
    loaded back afterwards.

    Args:
        output_path: Destination forwarded to ``save_model``.
        encoder: Encoder module currently being trained.
        head: Head module currently being trained.
        best_states: Dict with ``"encoder"`` and ``"head"`` state dicts
            holding the best weights seen so far.
        coord_mean: Forwarded to ``save_model``.
        coord_std: Forwarded to ``save_model``.
        args: Forwarded to ``save_model``.
    """
    # state_dict() hands back tensors that alias the live parameters, and
    # load_state_dict() copies into those parameters in place. Without an
    # explicit clone the "current" snapshot would be overwritten by the best
    # weights and the restore below would be a no-op.
    current_states = {
        "encoder": {
            key: value.detach().clone()
            for key, value in encoder.state_dict().items()
        },
        "head": {
            key: value.detach().clone()
            for key, value in head.state_dict().items()
        },
    }
    try:
        encoder.load_state_dict(best_states["encoder"])
        head.load_state_dict(best_states["head"])
        save_model(output_path, encoder, head, coord_mean, coord_std, args)
    finally:
        # Always put the in-progress weights back, even if saving failed.
        encoder.load_state_dict(current_states["encoder"])
        head.load_state_dict(current_states["head"])
def should_save_checkpoint(epoch, total_epochs, save_every_epochs, pending_save):
    """Decide whether the pending best checkpoint should be written now.

    Args:
        epoch: 1-based index of the epoch that just finished.
        total_epochs: Total number of epochs in the run.
        save_every_epochs: Save interval in epochs. Values below 1 are
            treated as 1 (save at every epoch with a pending checkpoint)
            instead of raising ``ZeroDivisionError`` in the modulo below.
        pending_save: Whether a new best state is waiting to be written.

    Returns:
        ``True`` when a save is pending and this is either the final epoch
        or a multiple of the save interval; ``False`` otherwise.
    """
    if not pending_save:
        return False
    if epoch == total_epochs:
        return True
    # Clamp so a zero/negative CLI value cannot crash the training loop.
    interval = max(1, save_every_epochs)
    return epoch % interval == 0
@torch.no_grad() @torch.no_grad()
def encode_texts(encoder, texts, batch_size, device): def encode_texts(encoder, texts, batch_size, device):
encoder.eval() encoder.eval()
@ -153,11 +235,13 @@ def encode_texts(encoder, texts, batch_size, device):
def train_head_epoch(head, dataloader, optimizer, loss_fn, device): def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
head.train() head.train()
total_loss = 0.0 total_loss = 0.0
pin_memory = dataloader.pin_memory
for embeddings, labels in dataloader: for embeddings, labels in dataloader:
embeddings = embeddings.to(device) embeddings, labels = move_tensors_to_device(
labels = labels.to(device) [embeddings, labels], device, pin_memory=pin_memory
optimizer.zero_grad() )
optimizer.zero_grad(set_to_none=True)
predictions = head(embeddings) predictions = head(embeddings)
loss = loss_fn(predictions, labels) loss = loss_fn(predictions, labels)
loss.backward() loss.backward()
@ -171,32 +255,49 @@ def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate ``head`` on precomputed embeddings.

    NOTE(review): the ``.cpu().numpy()`` conversions below require tensors
    without grad history; presumably this function is wrapped in
    ``@torch.no_grad()`` on a line outside this view -- confirm.

    Args:
        head: Module mapping embeddings to normalized coordinates.
        dataloader: Loader yielding ``(embeddings, labels)`` batches; its
            ``pin_memory`` attribute selects the transfer mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a tensor.
        coord_mean: Offset added back to de-normalize coordinates.
        coord_std: Scale multiplied back to de-normalize coordinates.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, mean_error_km)``: the sample-weighted loss sum divided
        by ``len(dataloader.dataset)``, and the mean of ``haversine_km`` over
        all de-normalized prediction/label pairs.
    """
    head.eval()
    use_pinned = dataloader.pin_memory
    loss_sum = 0.0
    batch_predictions = []
    batch_labels = []

    for embeddings, labels in dataloader:
        embeddings, labels = move_tensors_to_device(
            [embeddings, labels], device, pin_memory=use_pinned
        )
        predictions = head(embeddings)
        batch_loss = loss_fn(predictions, labels)
        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += batch_loss.item() * labels.size(0)
        batch_predictions.append(predictions)
        batch_labels.append(labels)

    # De-normalize once over the whole set rather than per batch.
    predicted = torch.cat(batch_predictions).float().cpu().numpy()
    pred_coords = predicted * coord_std + coord_mean
    actual = torch.cat(batch_labels).float().cpu().numpy()
    true_coords = actual * coord_std + coord_mean
    errors_km = haversine_km(pred_coords, true_coords)
    mean_loss = loss_sum / len(dataloader.dataset)
    return mean_loss, float(np.mean(errors_km))
def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable): def train_epoch(
encoder,
head,
dataloader,
optimizer,
loss_fn,
device,
encoder_trainable,
):
if encoder_trainable: if encoder_trainable:
encoder.train() encoder.train()
else: else:
encoder.eval() encoder.eval()
head.train() head.train()
total_loss = 0.0 total_loss = 0.0
pin_memory = dataloader.pin_memory
for features, labels in dataloader: for features, labels in dataloader:
optimizer.zero_grad() features, labels = move_batch_to_device(
features, labels, device, pin_memory=pin_memory
)
optimizer.zero_grad(set_to_none=True)
if encoder_trainable: if encoder_trainable:
embeddings = encoder(features)["sentence_embedding"] embeddings = encoder(features)["sentence_embedding"]
else: else:
@ -212,21 +313,28 @@ def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable
@torch.no_grad()
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std, device):
    """Evaluate the encoder + head pipeline on tokenized text batches.

    Args:
        encoder: Module called as ``encoder(features)``; its
            ``"sentence_embedding"`` output is fed to ``head``.
        head: Module mapping embeddings to normalized coordinates.
        dataloader: Loader yielding ``(features, labels)`` batches; its
            ``pin_memory`` attribute selects the transfer mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a tensor.
        coord_mean: Offset added back to de-normalize coordinates.
        coord_std: Scale multiplied back to de-normalize coordinates.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, mean_error_km)``: the sample-weighted loss sum divided
        by ``len(dataloader.dataset)``, and the mean of ``haversine_km`` over
        all de-normalized prediction/label pairs.
    """
    encoder.eval()
    head.eval()
    use_pinned = dataloader.pin_memory
    loss_sum = 0.0
    batch_predictions = []
    batch_labels = []

    for features, labels in dataloader:
        features, labels = move_batch_to_device(
            features, labels, device, pin_memory=use_pinned
        )
        sentence_embeddings = encoder(features)["sentence_embedding"]
        predictions = head(sentence_embeddings)
        batch_loss = loss_fn(predictions, labels)
        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += batch_loss.item() * labels.size(0)
        batch_predictions.append(predictions)
        batch_labels.append(labels)

    # De-normalize once over the whole set rather than per batch.
    predicted = torch.cat(batch_predictions).float().cpu().numpy()
    pred_coords = predicted * coord_std + coord_mean
    actual = torch.cat(batch_labels).float().cpu().numpy()
    true_coords = actual * coord_std + coord_mean
    errors_km = haversine_km(pred_coords, true_coords)
    mean_loss = loss_sum / len(dataloader.dataset)
    return mean_loss, float(np.mean(errors_km))
@ -316,6 +424,7 @@ def main():
args = parse_args() args = parse_args()
set_seed(args.seed) set_seed(args.seed)
device = get_device(args.device) device = get_device(args.device)
pin_memory = device.type == "cuda"
print(f"Using device: {device}") print(f"Using device: {device}")
data = pd.read_csv(args.data_file) data = pd.read_csv(args.data_file)
@ -370,37 +479,51 @@ def main():
val_dataset = EmbeddingCoordinateDataset( val_dataset = EmbeddingCoordinateDataset(
all_embeddings[val_indices], normalized_coordinates[val_indices] all_embeddings[val_indices], normalized_coordinates[val_indices]
) )
train_loader = DataLoader( train_loader = make_dataloader(
train_dataset, train_dataset,
batch_size=args.batch_size, args.batch_size,
shuffle=True, shuffle=True,
collate_fn=embedding_collate,
num_workers=args.num_workers,
pin_memory=pin_memory,
) )
val_loader = DataLoader( val_loader = make_dataloader(
val_dataset, val_dataset,
batch_size=args.batch_size, args.batch_size,
shuffle=False, shuffle=False,
collate_fn=embedding_collate,
num_workers=args.num_workers,
pin_memory=pin_memory,
) )
else: else:
train_loader = DataLoader( text_collate = make_text_collate(encoder.tokenize)
train_loader = make_dataloader(
train_dataset, train_dataset,
batch_size=args.batch_size, args.batch_size,
shuffle=True, shuffle=True,
collate_fn=collate_fn(encoder, device), collate_fn=text_collate,
num_workers=args.num_workers,
pin_memory=pin_memory,
) )
val_loader = DataLoader( val_loader = make_dataloader(
val_dataset, val_dataset,
batch_size=args.batch_size, args.batch_size,
shuffle=False, shuffle=False,
collate_fn=collate_fn(encoder, device), collate_fn=text_collate,
num_workers=args.num_workers,
pin_memory=pin_memory,
) )
optimizer = make_optimizer(encoder, head, args) optimizer = make_optimizer(encoder, head, args)
loss_fn = nn.MSELoss() loss_fn = nn.MSELoss()
best_val_loss = float("inf") best_val_loss = float("inf")
best_states = None
pending_save = False
print( print(
f"Training on {len(train_dataset):,} rows; " f"Training on {len(train_dataset):,} rows; "
f"validating on {len(val_dataset):,} rows" f"validating on {len(val_dataset):,} rows; "
f"batch_size={args.batch_size}; num_workers={args.num_workers}"
) )
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
if encoder_trainable == 0: if encoder_trainable == 0:
@ -417,10 +540,17 @@ def main():
train_loader, train_loader,
optimizer, optimizer,
loss_fn, loss_fn,
True, device,
encoder_trainable > 0,
) )
val_loss, val_error_km = evaluate( val_loss, val_error_km = evaluate(
encoder, head, val_loader, loss_fn, coord_mean, coord_std encoder,
head,
val_loader,
loss_fn,
coord_mean,
coord_std,
device,
) )
print( print(
f"epoch={epoch} train_loss={train_loss:.6f} " f"epoch={epoch} train_loss={train_loss:.6f} "
@ -429,8 +559,29 @@ def main():
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
save_model(args.output_path, encoder, head, coord_mean, coord_std, args) best_states = {
print(f"Saved best model to {args.output_path}") "encoder": copy_module_state(encoder),
"head": copy_module_state(head),
}
pending_save = True
if should_save_checkpoint(
epoch, args.epochs, args.save_every_epochs, pending_save
):
save_best_checkpoint(
args.output_path,
encoder,
head,
best_states,
coord_mean,
coord_std,
args,
)
pending_save = False
print(
f"Saved best model to {args.output_path} "
f"(val_loss={best_val_loss:.6f})"
)
if __name__ == "__main__": if __name__ == "__main__":