You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

87 lines
2.5 KiB

11 months ago
import argparse
import pytorch_lightning as pl
10 months ago
from pytorch_lightning.callbacks import EarlyStopping
11 months ago
10 months ago
from callbacks import SaveImageCallback
from dataloader import create_named_dataloader
11 months ago
from model import ColorTransformerModel
11 months ago
11 months ago
def parse_args(argv=None):
    """Parse the command-line options of the Color Transformer training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of reading the real
            command line (``sys.argv[1:]``); passing an explicit list lets
            tests or other scripts reuse this parser without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``bs``, ``alpha``, ``lr``,
        ``max_epochs``, ``log_every_n_steps`` and ``num_workers``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")
    parser.add_argument(
        "--bs",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument(
        "-w",
        "--num_workers",
        type=int,
        default=3,
        help="Number of workers for data loading",
    )
    # argparse treats argv=None as "use sys.argv[1:]", so the default call is
    # exactly equivalent to the original zero-argument version.
    return parser.parse_args(argv)
11 months ago
if __name__ == "__main__":
    cli = parse_args()

    # Stop training once the monitored "hp_metric" value stops decreasing.
    # NOTE(review): assumes ColorTransformerModel logs a metric named
    # "hp_metric"; EarlyStopping needs it to exist — confirm in model.py.
    early_stopping = EarlyStopping(
        monitor="hp_metric",
        min_delta=1e-5,  # smallest decrease that still counts as an improvement
        patience=24,  # checks without improvement tolerated before stopping
        mode="min",  # lower hp_metric is better
        verbose=True,
    )

    # Image-saving callback; the exact meaning of save_interval/final_dir is
    # defined in callbacks.SaveImageCallback (not visible here).
    image_saver = SaveImageCallback(save_interval=1, final_dir="out")

    # The named dataloader also provides grayscale extras.
    # TODO: remove the unnamed dataloader variant.
    train_loader = create_named_dataloader(
        N=0,
        batch_size=cli.bs,
        shuffle=True,
        num_workers=cli.num_workers,
    )

    # Hyperparameters are handed to the model bundled in a single namespace.
    hparams = argparse.Namespace(
        alpha=cli.alpha,
        learning_rate=cli.lr,
        batch_size=cli.bs,
    )
    model = ColorTransformerModel(hparams)

    trainer = pl.Trainer(
        callbacks=[early_stopping, image_saver],
        max_epochs=cli.max_epochs,
        log_every_n_steps=cli.log_every_n_steps,
    )
    trainer.fit(model, train_loader)