Michael Pilosov, PhD
10 months ago
5 changed files with 138 additions and 110 deletions
@ -0,0 +1,77 @@ |
|||
import lightning as L |
|||
import torch |
|||
from matplotlib.colors import rgb_to_hsv |
|||
from torch.utils.data import DataLoader |
|||
|
|||
from utils import extract_colors, preprocess_data |
|||
|
|||
|
|||
class ColorDataModule(L.LightningDataModule): |
|||
def __init__(self, val_size: int = 10_000, train_size=0, batch_size: int = 32): |
|||
super().__init__() |
|||
self.val_size = val_size |
|||
self.batch_size = batch_size |
|||
self.train_size = train_size |
|||
|
|||
def prepare_data(self): |
|||
# no state. called from main process. |
|||
pass |
|||
|
|||
def setup(self, stage: str): |
|||
rgb_tensor, xkcd_color_names = extract_colors() |
|||
rgb_tensor = preprocess_data(rgb_tensor, skip=True) |
|||
self.xkcd_colors = [ |
|||
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) |
|||
for i in range(len(rgb_tensor)) |
|||
] |
|||
if self.train_size > 0: |
|||
train_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) |
|||
train_rgb = preprocess_data(train_rgb, skip=True) |
|||
self.train_colors = [ |
|||
( |
|||
train_rgb[i], |
|||
torch.tensor(rgb_to_hsv(train_rgb[i])[:, 0], dtype=torch.float32), |
|||
) |
|||
for i in range(len(train_rgb)) |
|||
] |
|||
|
|||
val_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) |
|||
val_rgb = preprocess_data(val_rgb, skip=True) |
|||
self.random_colors = [ |
|||
( |
|||
val_rgb[i], |
|||
torch.tensor(rgb_to_hsv(val_rgb[i])[:, 0], dtype=torch.float32), |
|||
) |
|||
for i in range(len(val_rgb)) |
|||
] |
|||
|
|||
# Assign train/val datasets for use in dataloaders |
|||
if stage == "fit": |
|||
self.color_val = self.random_colors |
|||
if self.train_size > 0: |
|||
self.color_train = self.train_colors |
|||
else: |
|||
self.color_train = self.xkcd_colors |
|||
|
|||
# Assign test dataset for use in dataloader(s) |
|||
if stage == "test": |
|||
self.color_test = self.random_colors |
|||
|
|||
if stage == "predict": # for visualizing |
|||
self.color_predict = self.xkcd_colors |
|||
|
|||
def train_dataloader(self): |
|||
return DataLoader(self.color_train, batch_size=self.batch_size) |
|||
|
|||
def val_dataloader(self): |
|||
return DataLoader(self.color_val, batch_size=self.batch_size) |
|||
|
|||
def test_dataloader(self): |
|||
return DataLoader(self.color_test, batch_size=self.batch_size) |
|||
|
|||
def predict_dataloader(self): |
|||
return DataLoader(self.color_predict, batch_size=self.batch_size) |
|||
|
|||
def teardown(self, stage: str): |
|||
# Used to clean-up when the run is finished |
|||
pass |
@ -0,0 +1,28 @@ |
|||
from lightning.pytorch.cli import LightningCLI |
|||
|
|||
# from callbacks import SaveImageCallback |
|||
from datamodule import ColorDataModule |
|||
from model import ColorTransformerModel |
|||
|
|||
|
|||
def cli_main(): |
|||
cli = LightningCLI(ColorTransformerModel, ColorDataModule) # noqa: F841 |
|||
# note: don't call fit!! |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
cli_main() |
|||
# note: it is good practice to implement the CLI in a function and call it in the main if block |
|||
|
|||
|
|||
# save_img_callback = SaveImageCallback( |
|||
# save_interval=0, |
|||
# final_dir="out", |
|||
# ) |
|||
|
|||
# trainer = pl.Trainer( |
|||
# callbacks=[save_img_callback], |
|||
# ) |
|||
|
|||
# data = ColorDataModule() |
|||
# trainer.fit(model, data) |
Loading…
Reference in new issue