You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

102 lines
2.9 KiB

import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader
from utils import extract_colors, preprocess_data
class ColorDataModule(L.LightningDataModule):
    """Lightning data module serving RGB colors paired with labels.

    Validation and test samples are uniformly random RGB colors labelled with
    their hue (component 0 of ``matplotlib.colors.rgb_to_hsv``). Training
    samples are either random colors with hue labels (``train_size > 0``) or
    the xkcd named colors labelled with their name (``train_size == 0``).
    Prediction always uses the xkcd colors (used for visualization).

    Args:
        val_size: Number of random colors in the validation set and in the
            test set.
        train_size: Number of random training colors; ``0`` (the default)
            trains on the xkcd color list instead.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of DataLoader worker processes for every loader.
    """

    def __init__(
        self,
        val_size: int = 10_000,
        train_size: int = 0,
        batch_size: int = 32,
        num_workers: int = 3,
    ):
        super().__init__()
        self.val_size = val_size
        self.train_size = train_size
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # No state. Called from the main process; nothing to do here.
        pass

    @classmethod
    def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
        """Return the hue of RGB color ``v`` as a 1-element float32 tensor.

        NOTE(review): ``rgb_to_hsv`` rejects components outside [0, 1], so
        ``v`` presumably holds values in that range — confirm.
        """
        return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)

    @classmethod
    def get_random_colors(cls, size: int):
        """Return ``size`` ``(rgb, hue)`` pairs for uniformly random colors."""
        rgb = torch.rand((int(size), 3), dtype=torch.float32)
        # NOTE(review): preprocess_data lives in utils; with skip=True it
        # presumably leaves the values unchanged — confirm.
        rgb = preprocess_data(rgb, skip=True)
        # Hue is computed from the (possibly) preprocessed values.
        return [(c, cls.get_hue(c)) for c in rgb]

    @classmethod
    def get_xkcd_colors(cls):
        """Return ``(rgb, name)`` pairs for the xkcd colors, minus ``xkcd:``."""
        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        # Iterating the tensor yields one RGB row per color, matching the names.
        return [
            (rgb, name.replace("xkcd:", ""))
            for rgb, name in zip(rgb_tensor, xkcd_color_names)
        ]

    def setup(self, stage: str):
        """Build the datasets required for ``stage``.

        Lightning calls this with ``"fit"``, ``"validate"``, ``"test"`` or
        ``"predict"``. ``"validate"`` must also build the validation set,
        otherwise ``trainer.validate(...)`` fails in ``val_dataloader``.
        """
        # Validation set first, then training, so random-number consumption
        # for "fit" happens in the same order as before.
        if stage in ("fit", "validate"):
            self.color_val = self.get_random_colors(self.val_size)
        if stage == "fit":
            if self.train_size > 0:
                self.color_train = self.get_random_colors(self.train_size)
            else:
                # NOTE(review): these pairs carry string labels, unlike the
                # hue tensors used for validation — confirm the model does
                # not use labels during training.
                self.color_train = self.get_xkcd_colors()
        if stage == "test":
            self.color_test = self.get_random_colors(self.val_size)
        if stage == "predict":  # for visualizing
            self.color_predict = self.get_xkcd_colors()

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.color_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.color_val, shuffle=False)

    def test_dataloader(self):
        # The test colors are already uniformly random, so shuffling adds
        # nothing; Lightning also warns when evaluation loaders shuffle.
        return self._make_loader(self.color_test, shuffle=False)

    def predict_dataloader(self):
        return self._make_loader(self.color_predict, shuffle=False)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished; nothing to release here.
        pass
if __name__ == "__main__":
    # Smoke check: build the training and validation datasets and report sizes.
    # "fit" is the Lightning stage name; "train" matched no branch in setup()
    # and built nothing.
    cdm = ColorDataModule()
    cdm.setup("fit")
    print(cdm)
    print(f"train samples: {len(cdm.color_train)}, val samples: {len(cdm.color_val)}")