"""Command-line entry point that trains ColorTransformerModel with PyTorch Lightning."""

import argparse

import pytorch_lightning as pl

from dataloader import create_named_dataloader as init_data
from model import ColorTransformerModel


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for the training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), ``argparse`` reads ``sys.argv[1:]``, which is the
            behavior when the script is run from the shell.

    Returns:
        An ``argparse.Namespace`` with the attributes ``bs``, ``alpha``,
        ``lr``, ``max_epochs``, ``log_every_n_steps`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    parser.add_argument(
        "--bs",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument(
        "-w",
        "--num_workers",
        type=int,
        default=3,
        help="Number of workers for data loading",
    )

    return parser.parse_args(argv)


def main(argv=None) -> None:
    """Build the data loader, model and trainer from CLI options, then train.

    Args:
        argv: Optional argument strings forwarded to ``parse_args``; ``None``
            means "use ``sys.argv[1:]``".
    """
    args = parse_args(argv)

    # Training data; shuffling is always enabled for the training loader.
    train_dataloader = init_data(
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # NOTE(review): the meaning of `alpha` is defined inside
    # ColorTransformerModel (not visible here); the CLI help describes it
    # as the loss-function alpha.
    model = ColorTransformerModel(
        alpha=args.alpha,
        learning_rate=args.lr,
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )

    trainer.fit(model, train_dataloader)


if __name__ == "__main__":
    main()