optimizations?
This commit is contained in:
parent
44c7753856
commit
8f4d4c1057
223
train.py
223
train.py
@ -75,7 +75,22 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=1992)
|
parser.add_argument("--seed", type=int, default=1992)
|
||||||
parser.add_argument("--epochs", type=int, default=10)
|
parser.add_argument("--epochs", type=int, default=10)
|
||||||
parser.add_argument("--batch-size", type=int, default=32)
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="DataLoader workers to prefetch batches while the GPU trains.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every-epochs",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help=(
|
||||||
|
"Write the best checkpoint to disk every N epochs (and at the final "
|
||||||
|
"epoch). Validation still runs every epoch."
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument("--learning-rate", type=float, default=2e-5)
|
parser.add_argument("--learning-rate", type=float, default=2e-5)
|
||||||
parser.add_argument("--head-learning-rate", type=float, default=1e-3)
|
parser.add_argument("--head-learning-rate", type=float, default=1e-3)
|
||||||
parser.add_argument("--weight-decay", type=float, default=0.01)
|
parser.add_argument("--weight-decay", type=float, default=0.01)
|
||||||
@ -127,16 +142,83 @@ def normalize_coordinates(coordinates):
|
|||||||
return (coordinates - mean) / std, mean, std
|
return (coordinates - mean) / std, mean, std
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(model, device):
|
def move_features_to_device(features, device, non_blocking=False):
    """Return a new dict with every feature value transferred to ``device``.

    Args:
        features: Mapping of feature name to an object exposing ``.to()``.
        device: Target device passed to each value's ``.to()`` call.
        non_blocking: Forwarded unchanged to each ``.to()`` call.

    Returns:
        A fresh ``dict`` with the same keys, in the same order, whose values
        are the results of the ``.to()`` calls. The input mapping itself is
        not modified.
    """
    moved = {}
    for name, tensor in features.items():
        moved[name] = tensor.to(device, non_blocking=non_blocking)
    return moved
|
||||||
|
|
||||||
|
|
||||||
|
def move_batch_to_device(features, labels, device, pin_memory=False):
    """Transfer one ``(features, labels)`` batch to ``device``.

    When ``pin_memory`` is truthy and ``device.type`` is ``"cuda"``, every
    feature value and the labels are passed through ``.pin_memory()`` first,
    and the transfers are then issued with ``non_blocking`` enabled. In every
    other case the values are moved with a plain ``.to(device)``.

    Args:
        features: Mapping of feature name to tensor.
        labels: Label tensor for the batch.
        device: Target device; only its ``type`` attribute is inspected here.
        pin_memory: Whether pinned staging should be used for CUDA transfers.

    Returns:
        A ``(features_on_device, labels_on_device)`` tuple.
    """
    use_pinned_copy = pin_memory and device.type == "cuda"
    if use_pinned_copy:
        features = {name: tensor.pin_memory() for name, tensor in features.items()}
        labels = labels.pin_memory()
    device_features = move_features_to_device(
        features, device, non_blocking=use_pinned_copy
    )
    device_labels = labels.to(device, non_blocking=use_pinned_copy)
    return device_features, device_labels
|
||||||
|
|
||||||
|
|
||||||
|
def move_tensors_to_device(tensors, device, pin_memory=False):
    """Transfer each tensor in ``tensors`` to ``device`` and return them as a list.

    When ``pin_memory`` is truthy and ``device.type`` is ``"cuda"``, each
    tensor is passed through ``.pin_memory()`` first and the transfer is then
    issued with ``non_blocking`` enabled. Otherwise a plain ``.to(device)``
    is used.

    Args:
        tensors: Iterable of tensors to move.
        device: Target device; only its ``type`` attribute is inspected here.
        pin_memory: Whether pinned staging should be used for CUDA transfers.

    Returns:
        A new ``list`` of the moved tensors, in input order.
    """
    use_pinned_copy = pin_memory and device.type == "cuda"
    staged = [item.pin_memory() for item in tensors] if use_pinned_copy else tensors
    return [item.to(device, non_blocking=use_pinned_copy) for item in staged]
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_collate(tokenize):
|
||||||
def collate(batch):
|
def collate(batch):
|
||||||
texts, labels = zip(*batch)
|
texts, labels = zip(*batch)
|
||||||
features = model.tokenize(list(texts))
|
features = tokenize(list(texts))
|
||||||
features = {key: value.to(device) for key, value in features.items()}
|
return features, torch.stack(labels)
|
||||||
return features, torch.stack(labels).to(device)
|
|
||||||
|
|
||||||
return collate
|
return collate
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_collate(batch):
    """Collate ``(embedding, label)`` samples into two stacked tensors.

    Args:
        batch: Sequence of two-item samples, each an embedding tensor paired
            with its label tensor.

    Returns:
        ``(embeddings, labels)`` where each is the ``torch.stack`` of the
        corresponding sample components, in batch order.
    """
    # Transpose the list of pairs into one column per component.
    embedding_column, label_column = map(list, zip(*batch))
    stacked_embeddings = torch.stack(embedding_column)
    stacked_labels = torch.stack(label_column)
    return stacked_embeddings, stacked_labels
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataloader(dataset, batch_size, shuffle, collate_fn, num_workers, pin_memory):
    """Build a ``DataLoader`` with the project's shared loader settings.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the dataset.
        collate_fn: Callable that merges a list of samples into a batch.
        num_workers: Number of worker processes for loading.
        pin_memory: Passed through to ``DataLoader(pin_memory=...)``.

    Returns:
        The configured ``DataLoader``. ``persistent_workers=True`` is passed
        only when ``num_workers`` is greater than zero.
    """
    # Only pass persistent_workers when worker processes are requested.
    worker_options = {"persistent_workers": True} if num_workers > 0 else {}
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **worker_options,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def copy_module_state(module):
    """Return an independent CPU snapshot of ``module.state_dict()``.

    The snapshot must not share storage with the live module, otherwise it
    would silently change as the module keeps training. ``Tensor.cpu()``
    returns the original tensor when it already lives on the CPU, so a copy
    is forced explicitly with ``copy=True``.

    Args:
        module: Module whose ``state_dict()`` is captured.

    Returns:
        A ``dict`` mapping each state-dict key to a detached CPU tensor that
        owns its own storage.
    """
    return {
        key: value.detach().to("cpu", copy=True)
        for key, value in module.state_dict().items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def save_best_checkpoint(
    output_path, encoder, head, best_states, coord_mean, coord_std, args
):
    """Write the best-so-far weights to disk without disturbing training.

    The best weights are loaded into ``encoder`` and ``head`` temporarily,
    ``save_model`` is called, and the in-progress weights are then put back.

    Args:
        output_path: Destination forwarded to ``save_model``.
        encoder: Encoder module currently being trained.
        head: Head module currently being trained.
        best_states: Dict with ``"encoder"`` and ``"head"`` state dicts
            holding the best weights seen so far.
        coord_mean: Forwarded to ``save_model``.
        coord_std: Forwarded to ``save_model``.
        args: Forwarded to ``save_model``.
    """
    # state_dict() returns tensors that share storage with the live
    # parameters, and load_state_dict() copies into those parameters in
    # place. Without cloning, loading the best weights would also overwrite
    # this "backup", and the restore below would leave the model on the best
    # weights instead of the current ones.
    current_encoder_state = {
        key: value.detach().clone() for key, value in encoder.state_dict().items()
    }
    current_head_state = {
        key: value.detach().clone() for key, value in head.state_dict().items()
    }
    try:
        encoder.load_state_dict(best_states["encoder"])
        head.load_state_dict(best_states["head"])
        save_model(output_path, encoder, head, coord_mean, coord_std, args)
    finally:
        # Always put the in-progress weights back, even if saving failed.
        encoder.load_state_dict(current_encoder_state)
        head.load_state_dict(current_head_state)
|
||||||
|
|
||||||
|
|
||||||
|
def should_save_checkpoint(epoch, total_epochs, save_every_epochs, pending_save):
    """Decide whether the pending best checkpoint should be written now.

    Args:
        epoch: 1-based index of the epoch that just finished.
        total_epochs: Total number of epochs in the run.
        save_every_epochs: Write interval in epochs. A non-positive value
            disables periodic writes, so only the final epoch writes.
        pending_save: Whether a new best model is waiting to be written.

    Returns:
        ``True`` if a write is pending and this is either the final epoch or
        a multiple of ``save_every_epochs``; ``False`` otherwise.
    """
    if not pending_save:
        return False
    if epoch == total_epochs:
        return True
    if save_every_epochs <= 0:
        # Guard the modulo below: an interval of 0 would raise
        # ZeroDivisionError in the middle of a training run.
        return False
    return epoch % save_every_epochs == 0
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def encode_texts(encoder, texts, batch_size, device):
|
def encode_texts(encoder, texts, batch_size, device):
|
||||||
encoder.eval()
|
encoder.eval()
|
||||||
@ -153,11 +235,13 @@ def encode_texts(encoder, texts, batch_size, device):
|
|||||||
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
|
def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
|
||||||
head.train()
|
head.train()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
pin_memory = dataloader.pin_memory
|
||||||
|
|
||||||
for embeddings, labels in dataloader:
|
for embeddings, labels in dataloader:
|
||||||
embeddings = embeddings.to(device)
|
embeddings, labels = move_tensors_to_device(
|
||||||
labels = labels.to(device)
|
[embeddings, labels], device, pin_memory=pin_memory
|
||||||
optimizer.zero_grad()
|
)
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
predictions = head(embeddings)
|
predictions = head(embeddings)
|
||||||
loss = loss_fn(predictions, labels)
|
loss = loss_fn(predictions, labels)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -171,32 +255,49 @@ def train_head_epoch(head, dataloader, optimizer, loss_fn, device):
|
|||||||
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
|
def evaluate_head(head, dataloader, loss_fn, coord_mean, coord_std, device):
|
||||||
head.eval()
|
head.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
errors_km = []
|
predictions_all = []
|
||||||
|
labels_all = []
|
||||||
|
pin_memory = dataloader.pin_memory
|
||||||
|
|
||||||
for embeddings, labels in dataloader:
|
for embeddings, labels in dataloader:
|
||||||
embeddings = embeddings.to(device)
|
embeddings, labels = move_tensors_to_device(
|
||||||
labels = labels.to(device)
|
[embeddings, labels], device, pin_memory=pin_memory
|
||||||
|
)
|
||||||
predictions = head(embeddings)
|
predictions = head(embeddings)
|
||||||
loss = loss_fn(predictions, labels)
|
loss = loss_fn(predictions, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
predictions_all.append(predictions)
|
||||||
|
labels_all.append(labels)
|
||||||
|
|
||||||
pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
|
pred_coords = torch.cat(predictions_all).float().cpu().numpy() * coord_std + coord_mean
|
||||||
true_coords = labels.cpu().numpy() * coord_std + coord_mean
|
true_coords = torch.cat(labels_all).float().cpu().numpy() * coord_std + coord_mean
|
||||||
errors_km.extend(haversine_km(pred_coords, true_coords))
|
errors_km = haversine_km(pred_coords, true_coords)
|
||||||
|
|
||||||
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
|
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable):
|
def train_epoch(
|
||||||
|
encoder,
|
||||||
|
head,
|
||||||
|
dataloader,
|
||||||
|
optimizer,
|
||||||
|
loss_fn,
|
||||||
|
device,
|
||||||
|
encoder_trainable,
|
||||||
|
):
|
||||||
if encoder_trainable:
|
if encoder_trainable:
|
||||||
encoder.train()
|
encoder.train()
|
||||||
else:
|
else:
|
||||||
encoder.eval()
|
encoder.eval()
|
||||||
head.train()
|
head.train()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
pin_memory = dataloader.pin_memory
|
||||||
|
|
||||||
for features, labels in dataloader:
|
for features, labels in dataloader:
|
||||||
optimizer.zero_grad()
|
features, labels = move_batch_to_device(
|
||||||
|
features, labels, device, pin_memory=pin_memory
|
||||||
|
)
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
if encoder_trainable:
|
if encoder_trainable:
|
||||||
embeddings = encoder(features)["sentence_embedding"]
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
else:
|
else:
|
||||||
@ -212,21 +313,28 @@ def train_epoch(encoder, head, dataloader, optimizer, loss_fn, encoder_trainable
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std):
|
def evaluate(encoder, head, dataloader, loss_fn, coord_mean, coord_std, device):
|
||||||
encoder.eval()
|
encoder.eval()
|
||||||
head.eval()
|
head.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
errors_km = []
|
predictions_all = []
|
||||||
|
labels_all = []
|
||||||
|
pin_memory = dataloader.pin_memory
|
||||||
|
|
||||||
for features, labels in dataloader:
|
for features, labels in dataloader:
|
||||||
|
features, labels = move_batch_to_device(
|
||||||
|
features, labels, device, pin_memory=pin_memory
|
||||||
|
)
|
||||||
embeddings = encoder(features)["sentence_embedding"]
|
embeddings = encoder(features)["sentence_embedding"]
|
||||||
predictions = head(embeddings)
|
predictions = head(embeddings)
|
||||||
loss = loss_fn(predictions, labels)
|
loss = loss_fn(predictions, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
predictions_all.append(predictions)
|
||||||
|
labels_all.append(labels)
|
||||||
|
|
||||||
pred_coords = predictions.cpu().numpy() * coord_std + coord_mean
|
pred_coords = torch.cat(predictions_all).float().cpu().numpy() * coord_std + coord_mean
|
||||||
true_coords = labels.cpu().numpy() * coord_std + coord_mean
|
true_coords = torch.cat(labels_all).float().cpu().numpy() * coord_std + coord_mean
|
||||||
errors_km.extend(haversine_km(pred_coords, true_coords))
|
errors_km = haversine_km(pred_coords, true_coords)
|
||||||
|
|
||||||
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
|
return total_loss / len(dataloader.dataset), float(np.mean(errors_km))
|
||||||
|
|
||||||
@ -316,6 +424,7 @@ def main():
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
device = get_device(args.device)
|
device = get_device(args.device)
|
||||||
|
pin_memory = device.type == "cuda"
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
data = pd.read_csv(args.data_file)
|
data = pd.read_csv(args.data_file)
|
||||||
@ -370,37 +479,51 @@ def main():
|
|||||||
val_dataset = EmbeddingCoordinateDataset(
|
val_dataset = EmbeddingCoordinateDataset(
|
||||||
all_embeddings[val_indices], normalized_coordinates[val_indices]
|
all_embeddings[val_indices], normalized_coordinates[val_indices]
|
||||||
)
|
)
|
||||||
train_loader = DataLoader(
|
train_loader = make_dataloader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batch_size,
|
args.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
collate_fn=embedding_collate,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = make_dataloader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batch_size,
|
args.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
collate_fn=embedding_collate,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_loader = DataLoader(
|
text_collate = make_text_collate(encoder.tokenize)
|
||||||
|
train_loader = make_dataloader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batch_size,
|
args.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn(encoder, device),
|
collate_fn=text_collate,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = make_dataloader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batch_size,
|
args.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
collate_fn=collate_fn(encoder, device),
|
collate_fn=text_collate,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = make_optimizer(encoder, head, args)
|
optimizer = make_optimizer(encoder, head, args)
|
||||||
loss_fn = nn.MSELoss()
|
loss_fn = nn.MSELoss()
|
||||||
best_val_loss = float("inf")
|
best_val_loss = float("inf")
|
||||||
|
best_states = None
|
||||||
|
pending_save = False
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Training on {len(train_dataset):,} rows; "
|
f"Training on {len(train_dataset):,} rows; "
|
||||||
f"validating on {len(val_dataset):,} rows"
|
f"validating on {len(val_dataset):,} rows; "
|
||||||
|
f"batch_size={args.batch_size}; num_workers={args.num_workers}"
|
||||||
)
|
)
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, args.epochs + 1):
|
||||||
if encoder_trainable == 0:
|
if encoder_trainable == 0:
|
||||||
@ -417,10 +540,17 @@ def main():
|
|||||||
train_loader,
|
train_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
loss_fn,
|
loss_fn,
|
||||||
True,
|
device,
|
||||||
|
encoder_trainable > 0,
|
||||||
)
|
)
|
||||||
val_loss, val_error_km = evaluate(
|
val_loss, val_error_km = evaluate(
|
||||||
encoder, head, val_loader, loss_fn, coord_mean, coord_std
|
encoder,
|
||||||
|
head,
|
||||||
|
val_loader,
|
||||||
|
loss_fn,
|
||||||
|
coord_mean,
|
||||||
|
coord_std,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"epoch={epoch} train_loss={train_loss:.6f} "
|
f"epoch={epoch} train_loss={train_loss:.6f} "
|
||||||
@ -429,8 +559,29 @@ def main():
|
|||||||
|
|
||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
best_val_loss = val_loss
|
best_val_loss = val_loss
|
||||||
save_model(args.output_path, encoder, head, coord_mean, coord_std, args)
|
best_states = {
|
||||||
print(f"Saved best model to {args.output_path}")
|
"encoder": copy_module_state(encoder),
|
||||||
|
"head": copy_module_state(head),
|
||||||
|
}
|
||||||
|
pending_save = True
|
||||||
|
|
||||||
|
if should_save_checkpoint(
|
||||||
|
epoch, args.epochs, args.save_every_epochs, pending_save
|
||||||
|
):
|
||||||
|
save_best_checkpoint(
|
||||||
|
args.output_path,
|
||||||
|
encoder,
|
||||||
|
head,
|
||||||
|
best_states,
|
||||||
|
coord_mean,
|
||||||
|
coord_std,
|
||||||
|
args,
|
||||||
|
)
|
||||||
|
pending_save = False
|
||||||
|
print(
|
||||||
|
f"Saved best model to {args.output_path} "
|
||||||
|
f"(val_loss={best_val_loss:.6f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user