Michael Pilosov, PhD
10 months ago
5 changed files with 138 additions and 110 deletions
@ -0,0 +1,77 @@ |
|||||
|
import lightning as L |
||||
|
import torch |
||||
|
from matplotlib.colors import rgb_to_hsv |
||||
|
from torch.utils.data import DataLoader |
||||
|
|
||||
|
from utils import extract_colors, preprocess_data |
||||
|
|
||||
|
|
||||
|
class ColorDataModule(L.LightningDataModule): |
||||
|
def __init__(self, val_size: int = 10_000, train_size=0, batch_size: int = 32): |
||||
|
super().__init__() |
||||
|
self.val_size = val_size |
||||
|
self.batch_size = batch_size |
||||
|
self.train_size = train_size |
||||
|
|
||||
|
def prepare_data(self): |
||||
|
# no state. called from main process. |
||||
|
pass |
||||
|
|
||||
|
def setup(self, stage: str): |
||||
|
rgb_tensor, xkcd_color_names = extract_colors() |
||||
|
rgb_tensor = preprocess_data(rgb_tensor, skip=True) |
||||
|
self.xkcd_colors = [ |
||||
|
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) |
||||
|
for i in range(len(rgb_tensor)) |
||||
|
] |
||||
|
if self.train_size > 0: |
||||
|
train_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) |
||||
|
train_rgb = preprocess_data(train_rgb, skip=True) |
||||
|
self.train_colors = [ |
||||
|
( |
||||
|
train_rgb[i], |
||||
|
torch.tensor(rgb_to_hsv(train_rgb[i])[:, 0], dtype=torch.float32), |
||||
|
) |
||||
|
for i in range(len(train_rgb)) |
||||
|
] |
||||
|
|
||||
|
val_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) |
||||
|
val_rgb = preprocess_data(val_rgb, skip=True) |
||||
|
self.random_colors = [ |
||||
|
( |
||||
|
val_rgb[i], |
||||
|
torch.tensor(rgb_to_hsv(val_rgb[i])[:, 0], dtype=torch.float32), |
||||
|
) |
||||
|
for i in range(len(val_rgb)) |
||||
|
] |
||||
|
|
||||
|
# Assign train/val datasets for use in dataloaders |
||||
|
if stage == "fit": |
||||
|
self.color_val = self.random_colors |
||||
|
if self.train_size > 0: |
||||
|
self.color_train = self.train_colors |
||||
|
else: |
||||
|
self.color_train = self.xkcd_colors |
||||
|
|
||||
|
# Assign test dataset for use in dataloader(s) |
||||
|
if stage == "test": |
||||
|
self.color_test = self.random_colors |
||||
|
|
||||
|
if stage == "predict": # for visualizing |
||||
|
self.color_predict = self.xkcd_colors |
||||
|
|
||||
|
def train_dataloader(self): |
||||
|
return DataLoader(self.color_train, batch_size=self.batch_size) |
||||
|
|
||||
|
def val_dataloader(self): |
||||
|
return DataLoader(self.color_val, batch_size=self.batch_size) |
||||
|
|
||||
|
def test_dataloader(self): |
||||
|
return DataLoader(self.color_test, batch_size=self.batch_size) |
||||
|
|
||||
|
def predict_dataloader(self): |
||||
|
return DataLoader(self.color_predict, batch_size=self.batch_size) |
||||
|
|
||||
|
def teardown(self, stage: str): |
||||
|
# Used to clean-up when the run is finished |
||||
|
pass |
@ -0,0 +1,28 @@ |
|||||
|
from lightning.pytorch.cli import LightningCLI |
||||
|
|
||||
|
# from callbacks import SaveImageCallback |
||||
|
from datamodule import ColorDataModule |
||||
|
from model import ColorTransformerModel |
||||
|
|
||||
|
|
||||
|
def cli_main(): |
||||
|
cli = LightningCLI(ColorTransformerModel, ColorDataModule) # noqa: F841 |
||||
|
# note: don't call fit!! |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
cli_main() |
||||
|
# note: it is good practice to implement the CLI in a function and call it in the main if block |
||||
|
|
||||
|
|
||||
|
# save_img_callback = SaveImageCallback( |
||||
|
# save_interval=0, |
||||
|
# final_dir="out", |
||||
|
# ) |
||||
|
|
||||
|
# trainer = pl.Trainer( |
||||
|
# callbacks=[save_img_callback], |
||||
|
# ) |
||||
|
|
||||
|
# data = ColorDataModule() |
||||
|
# trainer.fit(model, data) |
Loading…
Reference in new issue