|
|
@ -4,6 +4,7 @@ import pytorch_lightning as pl |
|
|
|
|
|
|
|
from dataloader import create_named_dataloader |
|
|
|
from model import ColorTransformerModel |
|
|
|
from pytorch_lightning.callbacks import EarlyStopping |
|
|
|
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
@ -42,6 +43,13 @@ def parse_args(): |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
args = parse_args() |
|
|
|
early_stop_callback = EarlyStopping( |
|
|
|
monitor='hp_metric', # Metric to monitor |
|
|
|
min_delta=0, # Minimum change in the monitored quantity to qualify as an improvement |
|
|
|
patience=50, # Number of epochs with no improvement after which training will be stopped |
|
|
|
mode='min', # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. |
|
|
|
verbose=True, |
|
|
|
) |
|
|
|
|
|
|
|
# Initialize data loader with parsed arguments |
|
|
|
# named_data_loader also has grayscale extras. TODO: remove unnamed |
|
|
@ -63,6 +71,7 @@ if __name__ == "__main__": |
|
|
|
|
|
|
|
# Initialize trainer with parsed arguments |
|
|
|
trainer = pl.Trainer( |
|
|
|
callbacks=[early_stop_callback], |
|
|
|
max_epochs=args.max_epochs, |
|
|
|
log_every_n_steps=args.log_every_n_steps, |
|
|
|
) |
|
|
|