import argparse

import pytorch_lightning as pl

from dataloader import create_named_dataloader as init_data
from model import ColorTransformerModel

if __name__ == "__main__":
    # Define argument parser
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    # Add arguments
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument(
        "-lr", "--learning_rate", type=float, default=1e-5, help="Learning rate"
    )
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-log", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument(
        "--num_workers", type=int, default=3, help="Number of workers for data loading"
    )
    parser.add_argument(
        "-ds",
        "--distinct_threshold",
        type=float,
        default=0.5,
        help="Threshold for color distinctness penalty",
    )

    # Parse arguments
    args = parser.parse_args()

    # Initialize data loader with parsed arguments
    train_dataloader = init_data(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Initialize model with parsed arguments
    model = ColorTransformerModel(
        alpha=args.alpha,
        distinct_threshold=args.distinct_threshold,
        learning_rate=args.learning_rate,
    )

    # Initialize trainer with parsed arguments
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )

    # Train the model
    trainer.fit(model, train_dataloader)
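
    # Example invocation (a sketch only: the script filename "train.py" is an
    # assumption, and the flag values below are arbitrary illustrations, not
    # recommended settings):
    #
    #   python train.py --batch_size 128 --learning_rate 1e-4 --max_epochs 500 --num_workers 4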