From e1ac3211b90117202093084137cb6126d80aed6c Mon Sep 17 00:00:00 2001
From: "Michael Pilosov, PhD"
Date: Sat, 27 Jan 2024 09:14:14 +0000
Subject: [PATCH] use lightning CLI everywhere

Import Callback from `lightning` in callbacks.py, split
ColorDataModule.setup into classmethod helpers (get_hue,
get_random_colors, get_xkcd_colors) and add a num_workers option, add an
`alpha` hyperparameter to ColorTransformerModel, and drive the makefile
`test`/`help`/`search` targets and the new newsearch.py sweep through
`python newmain.py fit`.

---
Review notes (not part of the commit message):
* datamodule.py: the __main__ smoke test now calls setup("fit"). The
  previous "train" stage matched none of the branches in setup(), so no
  dataset attribute was assigned.
* datamodule.py: test_dataloader no longer shuffles, matching
  val_dataloader and predict_dataloader.
* NOTE(review): the makefile and newsearch.py invoke newmain.py, which
  this patch does not add; confirm it exists on the target branch.
* NOTE(review): newsearch.py passes --trainer.fast_dev_run 1 and
  --print_config to every job and overwrites its first `cmd` assignment;
  confirm that is intended for the sweep.

 callbacks.py  |  4 +--
 datamodule.py | 91 ++++++++++++++++++++++++++++++++-------------------
 losses.py     |  4 ++-
 makefile      | 32 ++++++++++++++++--
 model.py      |  5 +--
 newsearch.py  | 82 ++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 178 insertions(+), 40 deletions(-)
 create mode 100644 newsearch.py

diff --git a/callbacks.py b/callbacks.py
index 7ecfcc5..f17da49 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -1,11 +1,11 @@
 from pathlib import Path
 
-import pytorch_lightning as pl
+from lightning import Callback
 
 from check import create_circle
 
 
-class SaveImageCallback(pl.Callback):
+class SaveImageCallback(Callback):
     def __init__(self, save_interval=1, final_dir: str = None):
         self.save_interval = save_interval
         self.final_dir = final_dir
diff --git a/datamodule.py b/datamodule.py
index 5d9b3ba..f35c74c 100644
--- a/datamodule.py
+++ b/datamodule.py
@@ -7,71 +7,96 @@ from utils import extract_colors, preprocess_data
 
 
 class ColorDataModule(L.LightningDataModule):
-    def __init__(self, val_size: int = 10_000, train_size=0, batch_size: int = 32):
+    def __init__(
+        self,
+        val_size: int = 10_000,
+        train_size=0,
+        batch_size: int = 32,
+        num_workers: int = 3,
+    ):
         super().__init__()
         self.val_size = val_size
-        self.batch_size = batch_size
         self.train_size = train_size
+        self.batch_size = batch_size
+        self.num_workers = num_workers
 
     def prepare_data(self):
         # no state. called from main process.
         pass
 
-    def setup(self, stage: str):
+    @classmethod
+    def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
+        return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)
+
+    @classmethod
+    def get_random_colors(cls, size: int):
+        train_rgb = torch.rand((int(size), 3), dtype=torch.float32)
+        train_rgb = preprocess_data(train_rgb, skip=True)
+        return [(c, cls.get_hue(c)) for c in train_rgb]
+
+    @classmethod
+    def get_xkcd_colors(cls):
         rgb_tensor, xkcd_color_names = extract_colors()
         rgb_tensor = preprocess_data(rgb_tensor, skip=True)
-        self.xkcd_colors = [
+        return [
             (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
             for i in range(len(rgb_tensor))
         ]
-        if self.train_size > 0:
-            train_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32)
-            train_rgb = preprocess_data(train_rgb, skip=True)
-            self.train_colors = [
-                (
-                    train_rgb[i],
-                    torch.tensor(rgb_to_hsv(train_rgb[i])[:, 0], dtype=torch.float32),
-                )
-                for i in range(len(train_rgb))
-            ]
-
-        val_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32)
-        val_rgb = preprocess_data(val_rgb, skip=True)
-        self.random_colors = [
-            (
-                val_rgb[i],
-                torch.tensor(rgb_to_hsv(val_rgb[i])[:, 0], dtype=torch.float32),
-            )
-            for i in range(len(val_rgb))
-        ]
 
+    def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
         if stage == "fit":
-            self.color_val = self.random_colors
+            self.color_val = self.get_random_colors(self.val_size)
             if self.train_size > 0:
-                self.color_train = self.train_colors
+                self.color_train = self.get_random_colors(self.train_size)
             else:
-                self.color_train = self.xkcd_colors
+                self.color_train = self.get_xkcd_colors()
 
         # Assign test dataset for use in dataloader(s)
         if stage == "test":
-            self.color_test = self.random_colors
+            self.color_test = self.get_random_colors(self.val_size)
 
         if stage == "predict":  # for visualizing
-            self.color_predict = self.xkcd_colors
+            self.color_predict = self.get_xkcd_colors()
 
     def train_dataloader(self):
-        return DataLoader(self.color_train, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_train,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )
 
     def val_dataloader(self):
-        return DataLoader(self.color_val, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
 
     def test_dataloader(self):
-        return DataLoader(self.color_test, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_test,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,  # evaluation loaders are not shuffled (same as val/predict)
+        )
 
     def predict_dataloader(self):
-        return DataLoader(self.color_predict, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_predict,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
 
     def teardown(self, stage: str):
         # Used to clean-up when the run is finished
         pass
+
+
+if __name__ == "__main__":
+    cdm = ColorDataModule()
+    cdm.setup("fit")  # "fit" is the stage that assigns color_train / color_val
+    print(cdm)
diff --git a/losses.py b/losses.py
index 3de252f..eb0de91 100644
--- a/losses.py
+++ b/losses.py
@@ -18,7 +18,7 @@ from utils import PURE_RGB
 
 
 def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
-    # Distance Preservation Component
+    # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
     if target_inputs is None:
         target_inputs = inputs
@@ -51,6 +51,7 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
 
 def separation_loss(red, green, blue):
     # Separation Component
+    # TODO: remove
     # Encourages the model to keep R, G, B values equally separated in the transformed space
     red_loss = torch.abs(0 - red)
     green_loss = torch.abs(1 / 3 - green) / (2 / 3)
@@ -59,6 +60,7 @@ def separation_loss(red, green, blue):
 
 
 def calculate_separation_loss(model):
+    # TODO: remove
     # Wrapper function to calculate separation loss
     outputs = model(PURE_RGB.to(model.device))
     red, green, blue = outputs[0], outputs[1], outputs[2]
diff --git a/makefile b/makefile
index 254dae9..36bb623 100644
--- a/makefile
+++ b/makefile
@@ -4,10 +4,38 @@ lint:
 	flake8 --ignore E501,W503 *.py
 
 test:
-	python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
+	# python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
+	python newmain.py fit \
+		--seed_everything 21 \
+		--data.batch_size 256 \
+		--data.train_size 0 \
+		--data.val_size 100000 \
+		--model.alpha 0 \
+		--model.width 2048 \
+		--trainer.fast_dev_run 1 \
+		--trainer.min_epochs 1 \
+		--trainer.max_epochs 10 \
+		--trainer.check_val_every_n_epoch 1 \
+		--trainer.callbacks callbacks.SaveImageCallback \
+		--trainer.callbacks.init_args.final_dir out \
+		--trainer.callbacks.init_args.save_interval 0 \
+		--optimizer torch.optim.Adam \
+		--optimizer.init_args.lr 0.01 \
+		--lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
+		--lr_scheduler.init_args.patience 5 \
+		--lr_scheduler.init_args.cooldown 10 \
+		--lr_scheduler.init_args.factor 0.05 \
+		--lr_scheduler.init_args.verbose true \
+		--print_config
+
+
+help:
+	# python newmain.py fit --help --trainer.callbacks.help
+	# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
+	python newmain.py fit --help
 
 search:
-	python search.py
+	python newsearch.py
 
 hsv:
 	python hsv.py
diff --git a/model.py b/model.py
index e3e2276..b0ec076 100644
--- a/model.py
+++ b/model.py
@@ -14,6 +14,7 @@ class ColorTransformerModel(L.LightningModule):
         width: int = 128,
         depth: int = 1,
         bias: bool = False,
+        alpha: float = 0,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -64,7 +65,7 @@ class ColorTransformerModel(L.LightningModule):
         self.log("s_loss", s_loss)
         return loss
 
-    def validation_step(self):
+    def validation_step(self, batch):
         inputs, labels = batch  # these are true HSV labels - no learning allowed.
         outputs = self.forward(inputs)
         distance = torch.minimum(
@@ -79,7 +80,7 @@ class ColorTransformerModel(L.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(
             self.parameters(),
-            lr=self.hparams.learning_rate,
+            lr=0.1,
         )
         lr_scheduler = ReduceLROnPlateau(
             optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
diff --git a/newsearch.py b/newsearch.py
new file mode 100644
index 0000000..53269be
--- /dev/null
+++ b/newsearch.py
@@ -0,0 +1,82 @@
+import subprocess
+import sys
+from random import sample
+
+import numpy as np  # noqa: F401
+from lightning_sdk import Machine, Studio  # noqa: F401
+
+NUM_JOBS = 100
+
+# reference to the current studio
+# if you run outside of Lightning, you can pass the Studio name
+# studio = Studio()
+
+# use the jobs plugin
+# studio.install_plugin("jobs")
+# job_plugin = studio.installed_plugins["jobs"]
+
+# do a sweep over learning rates
+
+# Define the ranges or sets of values for each hyperparameter
+# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
+# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
+learning_rate_values = [1e-2]
+alpha_values = [0, 1, 2]
+widths = [2**k for k in range(4, 15)]
+# learning_rate_values = [5e-4]
+batch_size_values = [256]
+max_epochs_values = [100]
+seeds = list(range(21, 1992))
+
+# Generate all possible combinations of hyperparameters
+all_params = [
+    (alpha, lr, bs, me, s, w)
+    for alpha in alpha_values
+    for lr in learning_rate_values
+    for bs in batch_size_values
+    for me in max_epochs_values
+    for s in seeds
+    for w in widths
+]
+
+
+# perform random search with a limit
+search_params = sample(all_params, min(NUM_JOBS, len(all_params)))
+
+for idx, params in enumerate(search_params):
+    a, lr, bs, me, s, w = params
+    cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
+    cmd = f"""
+    python newmain.py fit \
+        --seed_everything {s} \
+        --data.batch_size {bs} \
+        --data.train_size 0 \
+        --data.val_size 100000 \
+        --model.alpha {a} \
+        --model.width {w} \
+        --trainer.fast_dev_run 1 \
+        --trainer.min_epochs 10 \
+        --trainer.max_epochs {me} \
+        --trainer.check_val_every_n_epoch 1 \
+        --trainer.callbacks callbacks.SaveImageCallback \
+        --trainer.callbacks.init_args.final_dir out \
+        --trainer.callbacks.init_args.save_interval 0 \
+        --optimizer torch.optim.Adam \
+        --optimizer.init_args.lr {lr} \
+        --lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
+        --lr_scheduler.init_args.patience 5 \
+        --lr_scheduler.init_args.cooldown 10 \
+        --lr_scheduler.init_args.factor 0.05 \
+        --lr_scheduler.init_args.verbose true \
+        --print_config
+    """
+
+    # job_name = f"color2_{bs}_{a}_{lr:2.2e}"
+    # job_plugin.run(cmd, machine=Machine.T4, name=job_name)
+    print(f"Running {params}: {cmd}")
+    try:
+        # Run the command and wait for it to complete
+        subprocess.run(cmd, shell=True, check=True)
+    except KeyboardInterrupt:
+        print("Interrupted by user")
+        sys.exit(1)