You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

107 lines
3.1 KiB

11 months ago
import argparse
10 months ago
import random
11 months ago
10 months ago
import numpy as np
11 months ago
import pytorch_lightning as pl
10 months ago
import torch
10 months ago
from pytorch_lightning.callbacks import EarlyStopping
11 months ago
10 months ago
from callbacks import SaveImageCallback
from dataloader import create_named_dataloader
11 months ago
from model import ColorTransformerModel
11 months ago
11 months ago
def parse_args(argv=None):
    """Parse command-line options for the Color Transformer training script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave
            exactly as before; passing a list allows programmatic use/tests.

    Returns:
        argparse.Namespace with attributes ``bs``, ``alpha``, ``lr``,
        ``max_epochs``, ``log_every_n_steps``, ``seed``, ``num_workers``
        and ``width``.

    Raises:
        SystemExit: on invalid arguments or ``--help`` (standard argparse
            behavior).
    """
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    # Optimisation / loss hyperparameters.
    parser.add_argument(
        "--bs",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")

    # Training-loop control.
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument("--seed", default=21, type=int, help="Seed")

    # Data loading and model size.
    parser.add_argument(
        "-w",
        "--num_workers",
        type=int,
        default=3,
        help="Number of workers for data loading",
    )
    parser.add_argument("--width", type=int, default=128, help="Max width of network.")

    # argv=None -> argparse falls back to sys.argv[1:].
    return parser.parse_args(argv)
10 months ago
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, NumPy's global RNG, the
            PyTorch CPU generator, the current CUDA device and all CUDA
            devices.
    """
    # Python's built-in generator and NumPy's legacy global generator.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch: CPU generator, then the current GPU, then every GPU.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
11 months ago
def main():
    """Parse CLI options, build data/model/callbacks, and run training."""
    args = parse_args()
    seed_everything(args.seed)

    # Stop training when "hp_metric" fails to decrease by at least
    # `min_delta` for `patience` consecutive checks.
    # NOTE(review): assumes ColorTransformerModel logs "hp_metric" itself;
    # confirm. Lightning's EarlyStopping defaults to strict=True and raises
    # if the monitored key is never logged.
    early_stop_callback = EarlyStopping(
        monitor="hp_metric",
        min_delta=1e-5,
        patience=24,
        mode="min",
        verbose=True,
    )

    # NOTE(review): presumably save_interval=0 disables periodic image dumps
    # and only final images go to "out" -- confirm in callbacks.py.
    save_img_callback = SaveImageCallback(
        save_interval=0,
        final_dir="out",
    )

    # named_data_loader also has grayscale extras. TODO: remove unnamed
    # NOTE(review): the meaning of N=0 is defined in dataloader.py -- confirm.
    train_dataloader = create_named_dataloader(
        N=0,
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Hyperparameters handed to the model (only the ones it consumes).
    params = argparse.Namespace(
        alpha=args.alpha,
        learning_rate=args.lr,
        batch_size=args.bs,
        width=args.width,
    )
    model = ColorTransformerModel(params)

    trainer = pl.Trainer(
        deterministic=True,  # pairs with seed_everything for reproducibility
        callbacks=[early_stop_callback, save_img_callback],
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )
    trainer.fit(model, train_dataloader)


if __name__ == "__main__":
    main()