batchsize

This commit is contained in:
mm 2023-05-05 00:50:57 +00:00
parent f75c99cb4e
commit 948c337ec2

View File

@ -55,28 +55,28 @@ train_examples, val_examples = train_test_split(
# validation examples can be something like templated sentences # validation examples can be something like templated sentences
# that maintain the same distance as the cities (same context) # that maintain the same distance as the cities (same context)
# should probably add training examples like that too if needed # should probably add training examples like that too if needed
batch_size = 16 BATCH_SIZE = 48
num_examples = len(train_examples) num_examples = len(train_examples)
steps_per_epoch = num_examples // batch_size steps_per_epoch = num_examples // BATCH_SIZE
print(f"\nHead of training data (size: {num_examples}):") print(f"\nHead of training data (size: {num_examples}):")
print(train_data[:10], "\n") print(train_data[:10], "\n")
# Create DataLoaders for train and validation datasets # Create DataLoaders for train and validation datasets
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16) train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
print("TRAINING") print("TRAINING")
# Configure the training arguments # Configure the training arguments
training_args = { training_args = {
"output_path": "./output", "output_path": "./output",
# "evaluation_steps": steps_per_epoch, # already evaluates at the end of each epoch # "evaluation_steps": steps_per_epoch, # already evaluates at the end of each epoch
"epochs": 20, "epochs": 10,
"warmup_steps": 500, "warmup_steps": 500,
"optimizer_params": {"lr": 2e-5}, "optimizer_params": {"lr": 2e-5},
# "weight_decay": 0, # not sure if this helps but works fine without setting it. # "weight_decay": 0, # not sure if this helps but works fine without setting it.
"scheduler": "WarmupLinear", "scheduler": "WarmupLinear",
"save_best_model": True, "save_best_model": True,
"checkpoint_path": "./checkpoints_absmax_split", "checkpoint_path": "./checkpoints",
"checkpoint_save_steps": steps_per_epoch, "checkpoint_save_steps": steps_per_epoch,
"checkpoint_save_total_limit": 100, "checkpoint_save_total_limit": 100,
} }