From 8f4d4c10575aeb2290a941fb9c79a57ba1cfb2f4 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 25 May 2026 21:45:08 +0000 Subject: [PATCH] optimizations? --- train.py | 223 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 187 insertions(+), 36 deletions(-) diff --git a/train.py b/train.py index 37be00e..4adb10d 100644 --- a/train.py +++ b/train.py @@ -75,7 +75,22 @@ def parse_args(): ) parser.add_argument("--seed", type=int, default=1992) parser.add_argument("--epochs", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="DataLoader workers to prefetch batches while the GPU trains.", + ) + parser.add_argument( + "--save-every-epochs", + type=int, + default=5, + help=( + "Write the best checkpoint to disk every N epochs (and at the final " + "epoch). Validation still runs every epoch." + ), + ) parser.add_argument("--learning-rate", type=float, default=2e-5) parser.add_argument("--head-learning-rate", type=float, default=1e-3) parser.add_argument("--weight-decay", type=float, default=0.01) @@ -127,16 +142,83 @@ def normalize_coordinates(coordinates): return (coordinates - mean) / std, mean, std -def collate_fn(model, device): +def move_features_to_device(features, device, non_blocking=False): + return { + key: value.to(device, non_blocking=non_blocking) + for key, value in features.items() + } + + +def move_batch_to_device(features, labels, device, pin_memory=False): + non_blocking = pin_memory and device.type == "cuda" + if pin_memory and device.type == "cuda": + features = {key: value.pin_memory() for key, value in features.items()} + labels = labels.pin_memory() + return ( + move_features_to_device(features, device, non_blocking=non_blocking), + labels.to(device, non_blocking=non_blocking), + ) + + +def move_tensors_to_device(tensors, device, pin_memory=False): + non_blocking = pin_memory and device.type == "cuda" + if pin_memory and device.type == "cuda": + tensors = [tensor.pin_memory() for tensor in tensors] + return [tensor.to(device, non_blocking=non_blocking) for tensor in tensors] + + +def make_text_collate(tokenize): def collate(batch): texts, labels = zip(*batch) - features = model.tokenize(list(texts)) - features = {key: value.to(device) for key, value in features.items()} - return features, torch.stack(labels).to(device) + features = tokenize(list(texts)) + return features, torch.stack(labels) return collate +def embedding_collate(batch): + embeddings, labels = zip(*batch) + return torch.stack(embeddings), torch.stack(labels) + + +def make_dataloader(dataset, batch_size, shuffle, collate_fn, num_workers, pin_memory): + loader_kwargs = { + "dataset": dataset, + "batch_size": batch_size, + "shuffle": shuffle, + "collate_fn": collate_fn, + "num_workers": num_workers, + "pin_memory": pin_memory, + } + if num_workers > 0: + loader_kwargs["persistent_workers"] = True + return DataLoader(**loader_kwargs) + + +def copy_module_state(module): + return {key: value.detach().cpu() for key, value in module.state_dict().items()} + + +def save_best_checkpoint( + output_path, encoder, head, best_states, coord_mean, coord_std, args +): + encoder_state = encoder.state_dict() + head_state = head.state_dict() + encoder.load_state_dict(best_states["encoder"]) + head.load_state_dict(best_states["head"]) + save_model(output_path, encoder, head, coord_mean, coord_std, args) + encoder.load_state_dict(encoder_state) + head.load_state_dict(head_state) + + +def should_save_checkpoint(epoch, total_epochs, save_every_epochs, pending_save): + if not pending_save: + return False + if epoch == total_epochs: + return True + return epoch % save_every_epochs == 0 + + @torch.no_grad() def encode_texts(encoder, texts, batch_size, device): encoder.eval() @@ -153,11 +235,13 @@ def encode_texts(encoder, texts, batch_size, device): def train_head_epoch(head, dataloader, optimizer, loss_fn, device): head.train() total_loss = 0.0 + pin_memory = dataloader.pin_memory for embeddings, labels in dataloader: - embeddings = embeddings.to(device) - labels = labels.to(device) - optimizer.zero_grad() + embeddings, labels = move_tensors_to_device( + [embeddings, labels], device, pin_memory=pin_memory + ) + optimizer.zero_grad(set_to_none=True) predictions = head(embeddings) loss = loss_fn(predictions, labels) loss.backward() @@ -171,32 +255,49 @@ def train_head_epoch(head, dataloader, optimizer, loss_fn, device): def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device): head.eval() total_loss = 0.0 - errors_km = [] + predictions_all = [] + labels_all = [] + pin_memory = dataloader.pin_memory for embeddings, labels in dataloader: - embeddings = embeddings.to(device) - labels = labels.to(device) + embeddings, labels = move_tensors_to_device( + [embeddings, labels], device, pin_memory=pin_memory + ) predictions = head(embeddings) loss = loss_fn(predictions, labels) total_loss += loss.item() * labels.size(0) + predictions_all.append(predictions) + labels_all.append(labels) - pred_coords = predictions.cpu().numpy() * coord_std + coord_mean - true_coords = labels.cpu().numpy() * coord_std + coord_mean - errors_km.extend(haversine_km(pred_coords, true_coords)) + pred_coords = torch.cat(predictions_all).float().cpu().numpy() * coord_std + coord_mean + true_coords = torch.cat(labels_all).float().cpu().numpy() * coord_std + coord_mean + errors_km = haversine_km(pred_coords, true_coords) return total_loss / len(dataloader.dataset), float(np.mean(errors_km)) -def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable): +def train_epoch( + encoder, + head, + dataloader, + optimizer, + loss_fn, + device, + encoder_trainable, +): if encoder_trainable: encoder.train() else: encoder.eval() head.train() total_loss = 0.0 + pin_memory = dataloader.pin_memory for features, labels in dataloader: - optimizer.zero_grad() + features, labels = move_batch_to_device( + features, labels, device, pin_memory=pin_memory + ) + optimizer.zero_grad(set_to_none=True) if encoder_trainable: embeddings = encoder(features)["sentence_embedding"] else: @@ -212,21 +313,28 @@ def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable @torch.no_grad() -def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std): +def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std, device): encoder.eval() head.eval() total_loss = 0.0 - errors_km = [] + predictions_all = [] + labels_all = [] + pin_memory = dataloader.pin_memory for features, labels in dataloader: + features, labels = move_batch_to_device( + features, labels, device, pin_memory=pin_memory + ) embeddings = encoder(features)["sentence_embedding"] predictions = head(embeddings) loss = loss_fn(predictions, labels) total_loss += loss.item() * labels.size(0) + predictions_all.append(predictions) + labels_all.append(labels) - pred_coords = predictions.cpu().numpy() * coord_std + coord_mean - true_coords = labels.cpu().numpy() * coord_std + coord_mean - errors_km.extend(haversine_km(pred_coords, true_coords)) + pred_coords = torch.cat(predictions_all).float().cpu().numpy() * coord_std + coord_mean + true_coords = torch.cat(labels_all).float().cpu().numpy() * coord_std + coord_mean + errors_km = haversine_km(pred_coords, true_coords) return total_loss / len(dataloader.dataset), float(np.mean(errors_km)) @@ -316,6 +424,7 @@ def main(): args = parse_args() set_seed(args.seed) device = get_device(args.device) + pin_memory = device.type == "cuda" print(f"Using device: {device}") data = pd.read_csv(args.data_file) @@ -370,37 +479,51 @@ def main(): val_dataset = EmbeddingCoordinateDataset( all_embeddings[val_indices], normalized_coordinates[val_indices] ) - train_loader = DataLoader( + train_loader = make_dataloader( train_dataset, - batch_size=args.batch_size, + args.batch_size, shuffle=True, + collate_fn=embedding_collate, + num_workers=args.num_workers, + pin_memory=pin_memory, ) - val_loader = DataLoader( + val_loader = make_dataloader( val_dataset, - batch_size=args.batch_size, + args.batch_size, shuffle=False, + collate_fn=embedding_collate, + num_workers=args.num_workers, + pin_memory=pin_memory, ) else: - train_loader = DataLoader( + text_collate = make_text_collate(encoder.tokenize) + train_loader = make_dataloader( train_dataset, - batch_size=args.batch_size, + args.batch_size, shuffle=True, - collate_fn=collate_fn(encoder, device), + collate_fn=text_collate, + num_workers=args.num_workers, + pin_memory=pin_memory, ) - val_loader = DataLoader( + val_loader = make_dataloader( val_dataset, - batch_size=args.batch_size, + args.batch_size, shuffle=False, - collate_fn=collate_fn(encoder, device), + collate_fn=text_collate, + num_workers=args.num_workers, + pin_memory=pin_memory, ) optimizer = make_optimizer(encoder, head, args) loss_fn = nn.MSELoss() best_val_loss = float("inf") + best_states = None + pending_save = False print( f"Training on {len(train_dataset):,} rows; " - f"validating on {len(val_dataset):,} rows" + f"validating on {len(val_dataset):,} rows; " + f"batch_size={args.batch_size}; num_workers={args.num_workers}" ) for epoch in range(1, args.epochs + 1): if encoder_trainable == 0: @@ -417,10 +540,17 @@ def main(): train_loader, optimizer, loss_fn, - True, + device, + encoder_trainable > 0, ) val_loss, val_error_km = evaluate( - encoder, head, val_loader, loss_fn, coord_mean, coord_std + encoder, + head, + val_loader, + loss_fn, + coord_mean, + coord_std, + device, ) print( f"epoch={epoch} train_loss={train_loss:.6f} " @@ -429,8 +559,29 @@ def main(): if val_loss < best_val_loss: best_val_loss = val_loss - save_model(args.output_path, encoder, head, coord_mean, coord_std, args) - print(f"Saved best model to {args.output_path}") + best_states = { + "encoder": copy_module_state(encoder), + "head": copy_module_state(head), + } + pending_save = True + + if should_save_checkpoint( + epoch, args.epochs, args.save_every_epochs, pending_save + ): + save_best_checkpoint( + args.output_path, + encoder, + head, + best_states, + coord_mean, + coord_std, + args, + ) + pending_save = False + print( + f"Saved best model to {args.output_path} " + f"(val_loss={best_val_loss:.6f})" + ) if __name__ == "__main__":