|
@ -7,7 +7,7 @@ import torch |
|
|
from pytorch_lightning.callbacks import EarlyStopping |
|
|
from pytorch_lightning.callbacks import EarlyStopping |
|
|
|
|
|
|
|
|
from callbacks import SaveImageCallback |
|
|
from callbacks import SaveImageCallback |
|
|
from dataloader import create_named_dataloader |
|
|
from dataloader import create_dataloader |
|
|
from model import ColorTransformerModel |
|
|
from model import ColorTransformerModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -65,20 +65,20 @@ if __name__ == "__main__": |
|
|
early_stop_callback = EarlyStopping( |
|
|
early_stop_callback = EarlyStopping( |
|
|
monitor="hp_metric", # Metric to monitor |
|
|
monitor="hp_metric", # Metric to monitor |
|
|
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement |
|
|
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement |
|
|
patience=24, # Number of epochs with no improvement after which training will be stopped |
|
|
patience=5, # Number of epochs with no improvement after which training will be stopped |
|
|
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. |
|
|
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. |
|
|
verbose=True, |
|
|
verbose=True, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
save_img_callback = SaveImageCallback( |
|
|
save_img_callback = SaveImageCallback( |
|
|
save_interval=0, |
|
|
save_interval=1, |
|
|
final_dir="out", |
|
|
final_dir="out", |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
# Initialize data loader with parsed arguments |
|
|
# Initialize data loader with parsed arguments |
|
|
# named_data_loader also has grayscale extras. TODO: remove unnamed |
|
|
# named_data_loader also has grayscale extras. TODO: remove unnamed |
|
|
train_dataloader = create_named_dataloader( |
|
|
train_dataloader = create_dataloader( |
|
|
N=0, |
|
|
N=1e8, |
|
|
batch_size=args.bs, |
|
|
batch_size=args.bs, |
|
|
shuffle=True, |
|
|
shuffle=True, |
|
|
num_workers=args.num_workers, |
|
|
num_workers=args.num_workers, |
|
|