colors/main.py

108 lines
3.1 KiB
Python
Raw Normal View History

2023-12-30 04:37:06 +00:00
import argparse
2024-01-15 04:25:29 +00:00
import random
2023-12-30 04:37:06 +00:00
2024-01-15 04:25:29 +00:00
import numpy as np
2023-12-30 04:37:06 +00:00
import pytorch_lightning as pl
2024-01-15 04:25:29 +00:00
import torch
2024-01-16 04:37:22 +00:00
from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
2023-12-30 04:37:06 +00:00
2024-01-15 02:58:41 +00:00
from callbacks import SaveImageCallback
2024-01-16 04:37:22 +00:00
from dataloader import create_named_dataloader as create_dataloader
2023-12-30 04:37:06 +00:00
from model import ColorTransformerModel
2023-12-30 05:30:52 +00:00
2023-12-30 05:13:50 +00:00
def parse_args():
    """Build the CLI parser for the training script and return the parsed namespace.

    All options carry defaults, so the script runs with no arguments.
    """
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")
    parser.add_argument("--bs", type=int, default=64, help="Input batch size for training")
    parser.add_argument("-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train")
    parser.add_argument("-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency")
    parser.add_argument("--seed", default=21, type=int, help="Seed")
    parser.add_argument("-w", "--num_workers", type=int, default=3, help="Number of workers for data loading")
    parser.add_argument("--width", type=int, default=128, help="Max width of network.")
    return parser.parse_args()
2024-01-15 04:25:29 +00:00
def seed_everything(seed=42):
    """Seed every RNG in play (Python, NumPy, torch CPU and CUDA) with *seed*.

    Also pins cuDNN to its deterministic kernels and disables the autotuner
    benchmark mode so repeated runs produce identical results.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers every device in multi-GPU setups
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
2023-12-30 05:30:52 +00:00
if __name__ == "__main__":
    # Parse CLI flags first, then make every RNG deterministic before any
    # data loading or model construction happens.
    cli = parse_args()
    seed_everything(cli.seed)

    # Early stopping is currently disabled; re-enable by appending it to the
    # Trainer's callbacks list below.
    # early_stop_callback = EarlyStopping(
    #     monitor="hp_metric",  # Metric to monitor
    #     min_delta=1e-5,  # Minimum change in the monitored quantity to qualify as an improvement
    #     patience=5,  # Number of epochs with no improvement after which training will be stopped
    #     mode="min",  # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
    #     verbose=True,
    # )

    # NOTE(review): save_interval=0 presumably suppresses periodic snapshots so
    # only the final render lands in `out/` — confirm against SaveImageCallback.
    image_callback = SaveImageCallback(
        save_interval=0,
        final_dir="out",
    )

    # Initialize data loader with parsed arguments.
    # named_data_loader also has grayscale extras. TODO: remove unnamed
    loader = create_dataloader(
        # N=1e5,
        skip=False,
        batch_size=cli.bs,
        shuffle=True,
        num_workers=cli.num_workers,
    )

    # Bundle the hyperparameters the model consumes into one namespace.
    hparams = argparse.Namespace(
        alpha=cli.alpha,
        learning_rate=cli.lr,
        batch_size=cli.bs,
        width=cli.width,
    )

    # Initialize model with parsed arguments.
    model = ColorTransformerModel(hparams)

    # deterministic=True pairs with seed_everything above for reproducible runs.
    trainer = pl.Trainer(
        deterministic=True,
        callbacks=[image_callback],
        max_epochs=cli.max_epochs,
        log_every_n_steps=cli.log_every_n_steps,
    )

    # Train the model.
    trainer.fit(model, loader)