You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.0 KiB
103 lines
3.0 KiB
import lightning as L
|
|
import torch
|
|
from matplotlib.colors import rgb_to_hsv
|
|
from torch.utils.data import DataLoader
|
|
|
|
from utils import extract_colors, preprocess_data
|
|
|
|
|
|
class ColorDataModule(L.LightningDataModule):
    """LightningDataModule serving ``(colour, hue)`` pairs.

    Every sample is a tuple ``(c, hue)``:
    - ``c`` is one element of a colour tensor after
      ``preprocess_data(..., skip=True)``.
    - ``hue`` is a 1-element float32 tensor computed by ``get_hue``.

    The datasets are built as follows:
    - Validation and test sets: uniformly random colours.
    - Training set: random colours when ``train_size > 0``, otherwise the
      colours returned by ``utils.extract_colors``.
    - Predict set: always the ``extract_colors`` colours.

    Args:
        val_size: Number of random colours in the validation and test sets.
        train_size: Number of random training colours; ``0`` (the default)
            trains on the ``extract_colors`` colours instead.
        batch_size: Batch size shared by every dataloader.
        num_workers: ``num_workers`` passed to every ``DataLoader``.
    """

    def __init__(
        self,
        val_size: int = 10_000,
        train_size: int = 0,
        batch_size: int = 32,
        num_workers: int = 3,
    ):
        super().__init__()
        self.val_size = val_size
        self.train_size = train_size
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # No state to prepare: all data is generated in ``setup``.
        # Called from the main process only.
        pass

    @classmethod
    def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
        """Return the hue of ``v`` as a shape-``(1,)`` float32 tensor.

        The hue is the first component of ``matplotlib.colors.rgb_to_hsv(v)``.
        NOTE(review): assumes ``v`` is a single RGB triple of shape ``(3,)``.
        """
        return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)

    @classmethod
    def get_random_colors(cls, size: int):
        """Return ``size`` random colours, each paired with its hue.

        Each channel is drawn from ``torch.rand`` (uniform on [0, 1)). The
        colours then pass through ``preprocess_data(..., skip=True)``, and the
        hue is computed on the preprocessed colour.
        """
        rgb = torch.rand((int(size), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        return [(c, cls.get_hue(c)) for c in rgb]

    @classmethod
    def get_xkcd_colors(cls):
        """Return the colours from ``extract_colors``, each paired with its hue.

        The colour names that ``extract_colors`` also returns are not used:
        the label is the hue, matching ``get_random_colors``.
        """
        rgb_tensor, _color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        return [(c, cls.get_hue(c)) for c in rgb_tensor]

    def setup(self, stage: str):
        """Build the dataset(s) needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``"predict"``.
                Any other value builds nothing.
        """
        # ``trainer.validate()`` runs setup("validate") and then asks for
        # ``val_dataloader``, so the validation set is needed there too.
        # Built before the training set to keep the original RNG draw order.
        if stage in ("fit", "validate"):
            self.color_val = self.get_random_colors(self.val_size)

        if stage == "fit":
            if self.train_size > 0:
                self.color_train = self.get_random_colors(self.train_size)
            else:
                self.color_train = self.get_xkcd_colors()

        if stage == "test":
            self.color_test = self.get_random_colors(self.val_size)

        if stage == "predict":  # used for visualizing
            self.color_predict = self.get_xkcd_colors()

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.color_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.color_val, shuffle=False)

    def test_dataloader(self):
        # Evaluation order should be deterministic; shuffling test data
        # gains nothing, and Lightning warns about it.
        return self._make_loader(self.color_test, shuffle=False)

    def predict_dataloader(self):
        # Order must match ``color_predict`` so predictions can be mapped
        # back to their input colours.
        return self._make_loader(self.color_predict, shuffle=False)

    def teardown(self, stage: str):
        # Nothing to clean up: all datasets are plain in-memory lists.
        pass
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the training and validation sets and report their
    # sizes. "fit" is the Lightning stage that ``setup`` uses to build
    # ``color_train`` and ``color_val``; a non-stage name such as "train"
    # matches no branch and builds nothing.
    cdm = ColorDataModule()
    cdm.setup("fit")
    print(cdm)
    print(f"train samples: {len(cdm.color_train)}")
    print(f"val samples: {len(cdm.color_val)}")
|
|
|