
fine-tuning

new-sep-loss
Michael Pilosov, 10 months ago
parent / commit d2b66fae56

Changed files:
  1. main.py (9)
  2. search.py (2)

main.py (9)

@@ -4,6 +4,7 @@ import pytorch_lightning as pl
 from dataloader import create_named_dataloader
 from model import ColorTransformerModel
+from pytorch_lightning.callbacks import EarlyStopping

 def parse_args():
@@ -42,6 +43,13 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
+    early_stop_callback = EarlyStopping(
+        monitor='hp_metric',  # Metric to monitor
+        min_delta=0,  # Minimum change in the monitored quantity to qualify as an improvement
+        patience=50,  # Number of epochs with no improvement after which training will be stopped
+        mode='min',  # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
+        verbose=True,
+    )
     # Initialize data loader with parsed arguments
     # named_data_loader also has grayscale extras. TODO: remove unnamed
@@ -63,6 +71,7 @@ if __name__ == "__main__":
     # Initialize trainer with parsed arguments
     trainer = pl.Trainer(
+        callbacks=[early_stop_callback],
         max_epochs=args.max_epochs,
         log_every_n_steps=args.log_every_n_steps,
     )
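For readers unfamiliar with the callback added above, here is a minimal, self-contained sketch of the same wiring: an EarlyStopping callback that watches a logged metric and is handed to pl.Trainer via callbacks. The TinyModel, the random dataset, the shortened patience, and the choice to log hp_metric from training_step are placeholders for this sketch only; the repository's ColorTransformerModel and dataloader are not shown in this diff.

# Minimal sketch (not the repository's code): wiring EarlyStopping into a
# pl.Trainer the same way the diff above does. Only the callback usage and
# the monitored metric name mirror the commit.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import EarlyStopping


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 3)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), x)
        # Log under the name the callback monitors ("hp_metric" in the commit),
        # aggregated per epoch so the callback sees one value per epoch.
        self.log("hp_metric", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    early_stop = EarlyStopping(
        monitor="hp_metric",  # must match a logged metric name
        min_delta=0,
        patience=5,           # shortened here; the commit uses 50
        mode="min",
        verbose=True,
    )
    data = DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=16)
    trainer = pl.Trainer(max_epochs=50, callbacks=[early_stop], log_every_n_steps=1)
    trainer.fit(TinyModel(), data)

Because the callback is strict by default, the monitored name must actually be logged somewhere in the LightningModule, otherwise training stops with a misconfiguration error.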

search.py (2)

@@ -5,7 +5,7 @@ from random import sample
 import numpy as np
 from lightning_sdk import Machine, Studio  # noqa: F401
-NUM_JOBS = 64
+NUM_JOBS = 100
 # reference to the current studio
 # if you run outside of Lightning, you can pass the Studio name
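The search.py change only raises the job cap from 64 to 100. The rest of the script is not part of this diff, so the following is a hypothetical sketch of the pattern such a constant usually controls: enumerate a hyperparameter grid, use random.sample to pick NUM_JOBS configurations, and dispatch one run per configuration. The argument names (--lr, --batch_size, --alpha) and the grid values are illustrative, and the real script presumably submits jobs through lightning_sdk rather than printing commands.

# Hypothetical sketch only: search.py's sweep logic is not shown in this diff.
# It illustrates how a NUM_JOBS constant typically bounds a random search.
from itertools import product
from random import sample

NUM_JOBS = 100

# Illustrative hyperparameter grid.
grid = [
    {"lr": lr, "batch_size": bs, "alpha": a}
    for lr, bs, a in product([1e-2, 1e-3, 1e-4], [32, 64, 128], [0.1, 0.5, 1.0])
]

# Cap the sweep at NUM_JOBS distinct configurations.
configs = sample(grid, k=min(NUM_JOBS, len(grid)))

for i, cfg in enumerate(configs):
    cmd = (
        f"python main.py --lr {cfg['lr']} "
        f"--batch_size {cfg['batch_size']} --alpha {cfg['alpha']}"
    )
    # The real script would hand this off to a job runner; printing keeps the
    # sketch self-contained.
    print(f"job {i}: {cmd}")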
