You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
2.6 KiB

import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader
from utils import extract_colors, preprocess_data
class ColorDataModule(L.LightningDataModule):
    """Lightning data module serving XKCD named colors and random RGB samples.

    Datasets built by :meth:`setup`:

    * ``xkcd_colors`` -- ``(rgb, name)`` pairs from ``extract_colors()``, with
      the ``"xkcd:"`` prefix stripped from each name.
    * ``random_colors`` -- ``val_size`` random ``(rgb, hue)`` pairs.
    * ``train_colors`` -- ``train_size`` random ``(rgb, hue)`` pairs, built
      only when ``train_size > 0``.

    Args:
        val_size: Number of random samples used for validation and testing.
        train_size: Number of random samples used for training. When ``0``
            (the default) training falls back to the XKCD colors.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(
        self, val_size: int = 10_000, train_size: int = 0, batch_size: int = 32
    ):
        super().__init__()
        self.val_size = val_size
        self.batch_size = batch_size
        self.train_size = train_size

    def prepare_data(self):
        """No-op: nothing to download or cache (called from the main process)."""
        # no state. called from main process.
        pass

    def _random_hue_pairs(self, count: int) -> list:
        """Return ``count`` random ``(rgb, hue)`` samples as a list of tuples.

        RGB values are drawn uniformly from ``[0, 1)`` and passed through
        ``preprocess_data(..., skip=True)`` before the hue is computed.
        """
        rgb = torch.rand((int(count), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        # NOTE(review): the ``[:, 0]`` indexing assumes each row coming out of
        # ``preprocess_data`` is 2-D (e.g. shape (1, 3)) so that column 0 is
        # the hue channel -- confirm against ``utils.preprocess_data``.
        return [
            (row, torch.tensor(rgb_to_hsv(row)[:, 0], dtype=torch.float32))
            for row in rgb
        ]

    def setup(self, stage: str):
        """Build the datasets and bind the ones needed for ``stage``.

        Args:
            stage: One of ``"fit"``, ``"validate"``, ``"test"`` or
                ``"predict"``, as passed in by the Lightning ``Trainer``.
        """
        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        self.xkcd_colors = [
            (rgb, xkcd_color_names[i].replace("xkcd:", ""))
            for i, rgb in enumerate(rgb_tensor)
        ]

        # Train samples are drawn before validation samples so the order of
        # RNG consumption stays the same as before.
        if self.train_size > 0:
            # Size the training set from ``train_size``; it was previously
            # sized from ``val_size`` by mistake.
            self.train_colors = self._random_hue_pairs(self.train_size)
        self.random_colors = self._random_hue_pairs(self.val_size)

        # Assign train/val datasets for use in dataloaders.
        # "validate" is included so ``Trainer.validate`` finds ``color_val``.
        if stage in ("fit", "validate"):
            self.color_val = self.random_colors
            if self.train_size > 0:
                self.color_train = self.train_colors
            else:
                self.color_train = self.xkcd_colors

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.color_test = self.random_colors

        if stage == "predict":  # for visualizing
            self.color_predict = self.xkcd_colors

    def train_dataloader(self):
        """Return the training ``DataLoader`` (requires ``setup("fit")``)."""
        return DataLoader(self.color_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` over the random samples."""
        return DataLoader(self.color_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test ``DataLoader`` (requires ``setup("test")``)."""
        return DataLoader(self.color_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return the prediction ``DataLoader`` over the XKCD named colors."""
        return DataLoader(self.color_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        """No-op: there is nothing to clean up when the run is finished."""
        # Used to clean-up when the run is finished
        pass