import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader

from utils import extract_colors, preprocess_data


class ColorDataModule(L.LightningDataModule):
    """Lightning data module serving XKCD named colors and random RGB colors.

    - Training uses the XKCD color list when ``train_size == 0``; otherwise it
      uses ``train_size`` uniformly random colors paired with their hue.
    - Validation and test use ``val_size`` uniformly random colors paired with
      their hue.
    - Prediction uses the XKCD color list (``(rgb, name)`` pairs), intended for
      visualization.
    """

    def __init__(self, val_size: int = 10_000, train_size: int = 0, batch_size: int = 32):
        """Store dataset sizes and the batch size.

        Args:
            val_size: Number of random colors generated for validation/test.
            train_size: Number of random colors to train on. ``0`` (the
                default) trains on the XKCD color list instead.
            batch_size: Batch size used by every dataloader.
        """
        super().__init__()
        self.val_size = val_size
        self.batch_size = batch_size
        self.train_size = train_size

    def prepare_data(self):
        # No state. Called from the main process.
        pass

    @staticmethod
    def _random_hue_pairs(n: int) -> list:
        """Return ``n`` pairs of (preprocessed random RGB sample, hue target).

        The hue target is channel 0 of ``matplotlib.colors.rgb_to_hsv`` applied
        to each preprocessed sample, converted to a float32 tensor.
        """
        rgb = torch.rand((int(n), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        # NOTE(review): ``[:, 0]`` requires each preprocessed sample to be 2-D
        # (rows of RGB triples); that shape comes from preprocess_data — confirm.
        return [
            (sample, torch.tensor(rgb_to_hsv(sample)[:, 0], dtype=torch.float32))
            for sample in rgb
        ]

    def setup(self, stage: str):
        """Build the datasets and assign the ones needed for ``stage``.

        Args:
            stage: Lightning stage name ("fit", "validate", "test" or "predict").
        """
        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        self.xkcd_colors = [
            (rgb, name.replace("xkcd:", ""))
            for rgb, name in zip(rgb_tensor, xkcd_color_names)
        ]

        if self.train_size > 0:
            # Sized by train_size (not val_size) so the argument actually
            # controls how many random training colors are drawn.
            self.train_colors = self._random_hue_pairs(self.train_size)

        self.random_colors = self._random_hue_pairs(self.val_size)

        # Validation data is needed both while fitting and for trainer.validate().
        if stage in ("fit", "validate"):
            self.color_val = self.random_colors

        if stage == "fit":
            if self.train_size > 0:
                self.color_train = self.train_colors
            else:
                self.color_train = self.xkcd_colors

        if stage == "test":
            self.color_test = self.random_colors

        if stage == "predict":
            # For visualizing.
            self.color_predict = self.xkcd_colors

    def train_dataloader(self):
        return DataLoader(self.color_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.color_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.color_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.color_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished.
        pass