Browse Source

fine-tuning

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
d2b66fae56
  1. 9
      main.py
  2. 2
      search.py

9
main.py

@ -4,6 +4,7 @@ import pytorch_lightning as pl
from dataloader import create_named_dataloader
from model import ColorTransformerModel
from pytorch_lightning.callbacks import EarlyStopping
def parse_args():
@ -42,6 +43,13 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
early_stop_callback = EarlyStopping(
monitor='hp_metric', # Metric to monitor
min_delta=0, # Minimum change in the monitored quantity to qualify as an improvement
patience=50, # Number of epochs with no improvement after which training will be stopped
mode='min', # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
verbose=True,
)
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
@ -63,6 +71,7 @@ if __name__ == "__main__":
# Initialize trainer with parsed arguments
trainer = pl.Trainer(
callbacks=[early_stop_callback],
max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps,
)

2
search.py

@ -5,7 +5,7 @@ from random import sample
import numpy as np
from lightning_sdk import Machine, Studio # noqa: F401
NUM_JOBS = 64
NUM_JOBS = 100
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name

Loading…
Cancel
Save