import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader

from utils import extract_colors, preprocess_data


class ColorDataModule(L.LightningDataModule):
    """Lightning data module that pairs RGB color tensors with labels.

    Validation and test sets are uniformly random RGB colors labelled with
    their hue. The training set is either random colors (when
    ``train_size > 0``) or the XKCD color list, also labelled with hue.
    The predict set is the XKCD color list, used for visualization.

    Args:
        val_size: Number of random colors in the validation and test sets.
        train_size: Number of random training colors. When 0, the XKCD
            colors are used for training instead.
        batch_size: Batch size for every dataloader.
        num_workers: Worker processes for every dataloader.
    """

    def __init__(
        self,
        val_size: int = 10_000,
        train_size: int = 0,
        batch_size: int = 32,
        num_workers: int = 3,
    ):
        super().__init__()
        self.val_size = val_size
        self.train_size = train_size
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # No state. Called from the main process.
        pass

    @classmethod
    def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
        """Return the hue of RGB color ``v`` as a float32 tensor of shape (1,).

        ``rgb_to_hsv`` converts ``v`` to a NumPy array internally, so ``v``
        is assumed to be a CPU tensor of 3 RGB components (TODO confirm
        the value range produced by ``preprocess_data``).
        """
        return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)

    @classmethod
    def get_random_colors(cls, size: int):
        """Return ``size`` random ``(rgb, hue)`` pairs.

        Colors are drawn with ``torch.rand`` (uniform in [0, 1)) and then
        passed through ``preprocess_data(..., skip=True)``, a project
        helper whose exact effect is defined in ``utils``.
        """
        rgb = torch.rand((int(size), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        return [(c, cls.get_hue(c)) for c in rgb]

    @classmethod
    def get_xkcd_colors(cls, label="hues"):
        """Return the XKCD colors paired with a label.

        Args:
            label: ``"hues"`` pairs each color with its hue tensor.
                ``"names"`` pairs each color with its XKCD name, with the
                ``"xkcd:"`` prefix removed.

        Raises:
            ValueError: If ``label`` is neither ``"hues"`` nor ``"names"``.
        """
        # Validate before loading so a bad label fails fast, without doing
        # whatever work extract_colors() performs.
        if label not in ("hues", "names"):
            raise ValueError("Please specify `label` as one of ['hues', 'names'].")

        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        if label == "names":
            return [
                (color, name.replace("xkcd:", ""))
                for color, name in zip(rgb_tensor, xkcd_color_names)
            ]
        return [(c, cls.get_hue(c)) for c in rgb_tensor]

    def setup(self, stage: str):
        """Build the datasets needed for the given Lightning ``stage``.

        Handled stages are ``"fit"``, ``"validate"``, ``"test"`` and
        ``"predict"``. Any other value builds nothing.
        """
        # The validation set is needed both while fitting and for a
        # standalone trainer.validate() run. Without this, val_dataloader()
        # would hit an unset attribute in the "validate" stage.
        if stage in ("fit", "validate"):
            self.color_val = self.get_random_colors(self.val_size)

        if stage == "fit":
            if self.train_size > 0:
                self.color_train = self.get_random_colors(self.train_size)
            else:
                self.color_train = self.get_xkcd_colors()

        if stage == "test":
            self.color_test = self.get_random_colors(self.val_size)

        if stage == "predict":
            # The predict set is used for visualization.
            self.color_predict = self.get_xkcd_colors()

    def train_dataloader(self):
        """Return a shuffled loader over the training set."""
        return DataLoader(
            self.color_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation set."""
        return DataLoader(
            self.color_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over the test set."""
        # Evaluation loaders should be deterministic. Lightning also warns
        # when val/test dataloaders shuffle.
        return DataLoader(
            self.color_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def predict_dataloader(self):
        """Return an unshuffled loader over the predict (XKCD) set."""
        return DataLoader(
            self.color_predict,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def teardown(self, stage: str):
        # Used to clean up when the run is finished.
        pass


if __name__ == "__main__":
    cdm = ColorDataModule()
    # "fit" is a real Lightning stage. The original "train" matched no
    # branch in setup() and built nothing.
    cdm.setup("fit")
    print(cdm)