You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.2 KiB
110 lines
3.2 KiB
import argparse
|
|
import random
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
|
|
|
|
from callbacks import SaveImageCallback
|
|
from dataloader import create_named_dataloader as create_dataloader
|
|
from model import ColorTransformerModel
|
|
|
|
|
|
def parse_args(argv=None):
    """Define and parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse fall back to ``sys.argv[1:]``, exactly as
            before; passing an explicit list lets tests and notebooks call this
            without touching the real process arguments.

    Returns:
        argparse.Namespace with the attributes ``bs``, ``alpha``, ``lr``,
        ``max_epochs``, ``log_every_n_steps``, ``seed``, ``num_workers`` and
        ``width``.
    """
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    parser.add_argument(
        "--bs",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument("--seed", default=21, type=int, help="Seed")
    parser.add_argument(
        "-w",
        "--num_workers",
        type=int,
        default=3,
        help="Number of workers for data loading",
    )
    parser.add_argument("--width", type=int, default=128, help="Max width of network.")

    # argv=None -> argparse reads sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)
|
|
|
|
|
|
def seed_everything(seed=42):
    """Seed every RNG the training run touches and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators with the same ``seed``, then turns off cuDNN autotuning and
    requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # current CUDA device
        torch.cuda.manual_seed_all,  # every visible CUDA device (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # autotuning can pick different kernels run-to-run
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    seed_everything(args.seed)

    # Early stopping is currently disabled. To re-enable it, build
    #     EarlyStopping(monitor="hp_metric", min_delta=1e-5, patience=5,
    #                   mode="min", verbose=True)
    # and add it to the Trainer's ``callbacks`` list below.

    # NOTE(review): save_interval=0 presumably means "no periodic saves, only
    # the final output into ./out" — confirm against callbacks.SaveImageCallback.
    save_img_callback = SaveImageCallback(save_interval=0, final_dir="out")

    # The named loader also yields grayscale extras. TODO: remove the unnamed one.
    # An ``N=1e5`` keyword was previously tried here (left disabled).
    train_dataloader = create_dataloader(
        skip=True,
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Hyperparameters handed to the model: CLI-driven values first, followed by
    # fixed architecture choices.
    hparams = dict(
        alpha=args.alpha,
        learning_rate=args.lr,
        batch_size=args.bs,
        width=args.width,
        bias=False,
        transform="relu",
        depth=1,
    )
    params = argparse.Namespace(**hparams)
    model = ColorTransformerModel(params)

    trainer_kwargs = {
        "deterministic": True,  # complements seed_everything() above
        "callbacks": [save_img_callback],
        "max_epochs": args.max_epochs,
        "log_every_n_steps": args.log_every_n_steps,
    }
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model, train_dataloader)
|
|
|