import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from callbacks import SaveImageCallback
from dataloader import create_named_dataloader
from model import ColorTransformerModel


def parse_args():
    # Define argument parser
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    # Add arguments
    parser.add_argument(
        "--bs",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument(
        "-w",
        "--num_workers",
        type=int,
        default=3,
        help="Number of workers for data loading",
    )

    # Parse arguments
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    early_stop_callback = EarlyStopping(
        monitor="hp_metric",  # Metric to monitor
        min_delta=1e-5,  # Minimum change in the monitored quantity to qualify as an improvement
        patience=12,  # Number of epochs with no improvement after which training will be stopped
        mode="min",  # 'min' stops when the monitored quantity stops decreasing; 'max' when it stops increasing
        verbose=True,
    )

    save_img_callback = SaveImageCallback(
        save_interval=1,
        final_dir="out",
    )

    # Initialize data loader with parsed arguments
    # named_data_loader also has grayscale extras. TODO: remove unnamed
    train_dataloader = create_named_dataloader(
        N=0,
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )

    params = argparse.Namespace(
        alpha=args.alpha,
        learning_rate=args.lr,
        batch_size=args.bs,
    )

    # Initialize model with parsed arguments
    model = ColorTransformerModel(params)

    # Initialize trainer with parsed arguments
    trainer = pl.Trainer(
        callbacks=[early_stop_callback, save_img_callback],
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )

    # Train the model
    trainer.fit(model, train_dataloader)
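
    # Example invocation, sketched from the flags defined in parse_args() above.
    # The script filename is an assumption; adjust it to match this file's actual name:
    #
    #   python train.py --bs 128 --lr 1e-4 -e 500 -w 4
    #
    # Flags omitted here fall back to their argparse defaults (e.g. --alpha 0.5,
    # --log_every_n_steps 5).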