diff --git a/main.py b/main.py index a40256c..a78f21b 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ import pytorch_lightning as pl from dataloader import create_named_dataloader from model import ColorTransformerModel +from pytorch_lightning.callbacks import EarlyStopping def parse_args(): @@ -42,6 +43,13 @@ def parse_args(): if __name__ == "__main__": args = parse_args() + early_stop_callback = EarlyStopping( + monitor='hp_metric', # Metric to monitor + min_delta=0, # Minimum change in the monitored quantity to qualify as an improvement + patience=50, # Number of epochs with no improvement after which training will be stopped + mode='min', # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. + verbose=True, + ) # Initialize data loader with parsed arguments # named_data_loader also has grayscale extras. TODO: remove unnamed @@ -63,6 +71,7 @@ if __name__ == "__main__": # Initialize trainer with parsed arguments trainer = pl.Trainer( + callbacks=[early_stop_callback], max_epochs=args.max_epochs, log_every_n_steps=args.log_every_n_steps, ) diff --git a/search.py b/search.py index 7008cfa..5a3adf6 100644 --- a/search.py +++ b/search.py @@ -5,7 +5,7 @@ from random import sample import numpy as np from lightning_sdk import Machine, Studio # noqa: F401 -NUM_JOBS = 64 +NUM_JOBS = 100 # reference to the current studio # if you run outside of Lightning, you can pass the Studio name