
use lightning CLI everywhere

Branch: new-sep-loss
Michael Pilosov, PhD, 10 months ago
commit e1ac3211b9

Changed files:

  1. callbacks.py (4 lines changed)
  2. datamodule.py (91 lines changed)
  3. losses.py (4 lines changed)
  4. makefile (32 lines changed)
  5. model.py (5 lines changed)
  6. newsearch.py (82 lines changed)

callbacks.py (4 lines changed)

@@ -1,11 +1,11 @@
 from pathlib import Path

-import pytorch_lightning as pl
+from lightning import Callback

 from check import create_circle


-class SaveImageCallback(pl.Callback):
+class SaveImageCallback(Callback):
     def __init__(self, save_interval=1, final_dir: str = None):
         self.save_interval = save_interval
         self.final_dir = final_dir
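For context (not part of the diff): under the unified lightning namespace the callback plugs into a Trainer exactly as before. A minimal sketch, assuming callbacks.py is importable, equivalent to the CLI flags used in the makefile below:

    from lightning import Trainer

    from callbacks import SaveImageCallback

    # same wiring the CLI performs from --trainer.callbacks flags
    trainer = Trainer(callbacks=[SaveImageCallback(save_interval=0, final_dir="out")])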

datamodule.py (91 lines changed)

@@ -7,71 +7,96 @@ from utils import extract_colors, preprocess_data

 class ColorDataModule(L.LightningDataModule):
-    def __init__(self, val_size: int = 10_000, train_size=0, batch_size: int = 32):
+    def __init__(
+        self,
+        val_size: int = 10_000,
+        train_size=0,
+        batch_size: int = 32,
+        num_workers: int = 3,
+    ):
         super().__init__()
         self.val_size = val_size
-        self.batch_size = batch_size
         self.train_size = train_size
+        self.batch_size = batch_size
+        self.num_workers = num_workers

     def prepare_data(self):
         # no state. called from main process.
         pass

-    def setup(self, stage: str):
+    @classmethod
+    def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
+        return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)
+
+    @classmethod
+    def get_random_colors(cls, size: int):
+        train_rgb = torch.rand((int(size), 3), dtype=torch.float32)
+        train_rgb = preprocess_data(train_rgb, skip=True)
+        return [(c, cls.get_hue(c)) for c in train_rgb]
+
+    @classmethod
+    def get_xkcd_colors(cls):
         rgb_tensor, xkcd_color_names = extract_colors()
         rgb_tensor = preprocess_data(rgb_tensor, skip=True)
-        self.xkcd_colors = [
+        return [
             (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
             for i in range(len(rgb_tensor))
         ]
-        if self.train_size > 0:
-            train_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32)
-            train_rgb = preprocess_data(train_rgb, skip=True)
-            self.train_colors = [
-                (
-                    train_rgb[i],
-                    torch.tensor(rgb_to_hsv(train_rgb[i])[:, 0], dtype=torch.float32),
-                )
-                for i in range(len(train_rgb))
-            ]
-        val_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32)
-        val_rgb = preprocess_data(val_rgb, skip=True)
-        self.random_colors = [
-            (
-                val_rgb[i],
-                torch.tensor(rgb_to_hsv(val_rgb[i])[:, 0], dtype=torch.float32),
-            )
-            for i in range(len(val_rgb))
-        ]
+
+    def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
         if stage == "fit":
-            self.color_val = self.random_colors
+            self.color_val = self.get_random_colors(self.val_size)
             if self.train_size > 0:
-                self.color_train = self.train_colors
+                self.color_train = self.get_random_colors(self.train_size)
             else:
-                self.color_train = self.xkcd_colors
+                self.color_train = self.get_xkcd_colors()
         # Assign test dataset for use in dataloader(s)
         if stage == "test":
-            self.color_test = self.random_colors
+            self.color_test = self.get_random_colors(self.val_size)
         if stage == "predict":  # for visualizing
-            self.color_predict = self.xkcd_colors
+            self.color_predict = self.get_xkcd_colors()

     def train_dataloader(self):
-        return DataLoader(self.color_train, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_train,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )

     def val_dataloader(self):
-        return DataLoader(self.color_val, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )

     def test_dataloader(self):
-        return DataLoader(self.color_test, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_test,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )

     def predict_dataloader(self):
-        return DataLoader(self.color_predict, batch_size=self.batch_size)
+        return DataLoader(
+            self.color_predict,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )

     def teardown(self, stage: str):
         # Used to clean-up when the run is finished
         pass
+
+
+if __name__ == "__main__":
+    cdm = ColorDataModule()
+    cdm.setup("train")
+    print(cdm)
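A quick sketch (not part of the commit) of how the refactored datamodule is exercised; num_workers=0 is chosen here only to keep the example single-process. Note the __main__ block above passes stage="train", which none of the setup branches handle; "fit" is the stage Lightning actually uses:

    from datamodule import ColorDataModule

    # the classmethods are now usable without instantiating the datamodule
    pairs = ColorDataModule.get_random_colors(4)  # [(rgb_tensor, hue_tensor), ...]

    dm = ColorDataModule(val_size=100, batch_size=8, num_workers=0)
    dm.setup("fit")                               # xkcd train set, random val set
    batch = next(iter(dm.train_dataloader()))     # xkcd colors, since train_size=0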

losses.py (4 lines changed)

@@ -18,7 +18,7 @@ from utils import PURE_RGB

 def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
-    # Distance Preservation Component
+    # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
     if target_inputs is None:
         target_inputs = inputs
@@ -51,6 +51,7 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):

 def separation_loss(red, green, blue):
     # Separation Component
+    # TODO: remove
     # Encourages the model to keep R, G, B values equally separated in the transformed space
     red_loss = torch.abs(0 - red)
     green_loss = torch.abs(1 / 3 - green) / (2 / 3)
@@ -59,6 +60,7 @@ def separation_loss(red, green, blue):

 def calculate_separation_loss(model):
+    # TODO: remove
     # Wrapper function to calculate separation loss
     outputs = model(PURE_RGB.to(model.device))
     red, green, blue = outputs[0], outputs[1], outputs[2]
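The targets 0 and 1/3 (and, judging by the normalization, 2/3 for blue) place the pure primaries evenly around the hue circle. A small numeric check of the two terms visible in this hunk (a sketch, not part of the commit):

    import torch

    red, green = torch.tensor(0.10), torch.tensor(0.25)
    red_loss = torch.abs(0 - red)                    # 0.10: red hue should sit at 0
    green_loss = torch.abs(1 / 3 - green) / (2 / 3)  # ~0.125: green hue should sit at 1/3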

makefile (32 lines changed)

@@ -4,10 +4,38 @@ lint:
 	flake8 --ignore E501,W503 *.py

 test:
-	python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
+	# python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
+	python newmain.py fit \
+	--seed_everything 21 \
+	--data.batch_size 256 \
+	--data.train_size 0 \
+	--data.val_size 100000 \
+	--model.alpha 0 \
+	--model.width 2048 \
+	--trainer.fast_dev_run 1 \
+	--trainer.min_epochs 1 \
+	--trainer.max_epochs 10 \
+	--trainer.check_val_every_n_epoch 1 \
+	--trainer.callbacks callbacks.SaveImageCallback \
+	--trainer.callbacks.init_args.final_dir out \
+	--trainer.callbacks.init_args.save_interval 0 \
+	--optimizer torch.optim.Adam \
+	--optimizer.init_args.lr 0.01 \
+	--lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
+	--lr_scheduler.init_args.patience 5 \
+	--lr_scheduler.init_args.cooldown 10 \
+	--lr_scheduler.init_args.factor 0.05 \
+	--lr_scheduler.init_args.verbose true \
+	--print_config
+
+help:
+	# python newmain.py fit --help --trainer.callbacks.help
+	# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
+	python newmain.py fit --help

 search:
-	python search.py
+	python newsearch.py

 hsv:
 	python hsv.py
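newmain.py itself is not part of this commit. A minimal LightningCLI entrypoint consistent with the flags above, under the assumption that the model and datamodule classes live in model.py and datamodule.py, might look like this (a sketch, not the author's actual file):

    # newmain.py (hypothetical reconstruction)
    from lightning.pytorch.cli import LightningCLI

    from datamodule import ColorDataModule
    from model import ColorTransformerModel


    def main():
        # LightningCLI maps --model.*, --data.*, --trainer.*, --optimizer,
        # and --lr_scheduler arguments onto these classes automatically.
        LightningCLI(ColorTransformerModel, ColorDataModule)


    if __name__ == "__main__":
        main()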

model.py (5 lines changed)

@@ -14,6 +14,7 @@ class ColorTransformerModel(L.LightningModule):
         width: int = 128,
         depth: int = 1,
         bias: bool = False,
+        alpha: float = 0,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -64,7 +65,7 @@ class ColorTransformerModel(L.LightningModule):
         self.log("s_loss", s_loss)
         return loss

-    def validation_step(self):
+    def validation_step(self, batch):
         inputs, labels = batch  # these are true HSV labels - no learning allowed.
         outputs = self.forward(inputs)
         distance = torch.minimum(
@@ -79,7 +80,7 @@ class ColorTransformerModel(L.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(
             self.parameters(),
-            lr=self.hparams.learning_rate,
+            lr=0.1,
         )
         lr_scheduler = ReduceLROnPlateau(
             optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
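An observation, not stated in the diff: when --optimizer and --lr_scheduler are supplied on the newmain.py command line (as the makefile and newsearch.py do), LightningCLI installs its own configure_optimizers and warns that the model's version is being overridden, so the hardcoded SGD with lr=0.1 should only take effect in runs launched without those flags.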

newsearch.py (new file, 82 lines)

@@ -0,0 +1,82 @@
import subprocess
import sys
from random import sample

import numpy as np  # noqa: F401
from lightning_sdk import Machine, Studio  # noqa: F401

NUM_JOBS = 100

# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
# studio = Studio()

# use the jobs plugin
# studio.install_plugin("jobs")
# job_plugin = studio.installed_plugins["jobs"]

# do a sweep over learning rates
# Define the ranges or sets of values for each hyperparameter
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-2]
alpha_values = [0, 1, 2]
widths = [2**k for k in range(4, 15)]
# learning_rate_values = [5e-4]
batch_size_values = [256]
max_epochs_values = [100]
seeds = list(range(21, 1992))

# Generate all possible combinations of hyperparameters
all_params = [
    (alpha, lr, bs, me, s, w)
    for alpha in alpha_values
    for lr in learning_rate_values
    for bs in batch_size_values
    for me in max_epochs_values
    for s in seeds
    for w in widths
]

# perform random search with a limit
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))

for idx, params in enumerate(search_params):
    a, lr, bs, me, s, w = params
    # old main.py invocation; immediately superseded by the CLI command below
    cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
    cmd = f"""
python newmain.py fit \
--seed_everything {s} \
--data.batch_size {bs} \
--data.train_size 0 \
--data.val_size 100000 \
--model.alpha {a} \
--model.width {w} \
--trainer.fast_dev_run 1 \
--trainer.min_epochs 10 \
--trainer.max_epochs {me} \
--trainer.check_val_every_n_epoch 1 \
--trainer.callbacks callbacks.SaveImageCallback \
--trainer.callbacks.init_args.final_dir out \
--trainer.callbacks.init_args.save_interval 0 \
--optimizer torch.optim.Adam \
--optimizer.init_args.lr {lr} \
--lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
--lr_scheduler.init_args.patience 5 \
--lr_scheduler.init_args.cooldown 10 \
--lr_scheduler.init_args.factor 0.05 \
--lr_scheduler.init_args.verbose true \
--print_config
"""
    # job_name = f"color2_{bs}_{a}_{lr:2.2e}"
    # job_plugin.run(cmd, machine=Machine.T4, name=job_name)
    print(f"Running {params}: {cmd}")
    try:
        # Run the command and wait for it to complete
        subprocess.run(cmd, shell=True, check=True)
    except KeyboardInterrupt:
        print("Interrupted by user")
        sys.exit(1)
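Note that as committed, every generated command includes --trainer.fast_dev_run 1 and --print_config, so the script validates up to NUM_JOBS sampled configurations (LightningCLI prints the composed config and exits when --print_config is given) rather than launching full training runs; the commented-out job_plugin lines are the hook for dispatching real jobs to Lightning Studio machines.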