You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.6 KiB
77 lines
2.6 KiB
import lightning as L
|
|
import torch
|
|
from matplotlib.colors import rgb_to_hsv
|
|
from torch.utils.data import DataLoader
|
|
|
|
from utils import extract_colors, preprocess_data
|
|
|
|
|
|
class ColorDataModule(L.LightningDataModule):
    """LightningDataModule serving RGB colors paired with hue targets or XKCD names.

    Datasets built in :meth:`setup`:

    * ``xkcd_colors`` -- ``(rgb, name)`` pairs from ``extract_colors()``, with the
      ``"xkcd:"`` prefix removed from every name. Used for training when
      ``train_size == 0`` and always for prediction (visualization).
    * ``train_colors`` -- ``train_size`` uniformly random RGB samples paired with
      their hue target; built only when ``train_size > 0``.
    * ``random_colors`` -- ``val_size`` uniformly random RGB samples paired with
      their hue target; used for validation and testing.

    Args:
        val_size: Number of random samples in the validation/test set.
        train_size: Number of random training samples. ``0`` (the default)
            trains on the XKCD colors instead.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(
        self, val_size: int = 10_000, train_size: int = 0, batch_size: int = 32
    ):
        super().__init__()
        self.val_size = val_size
        self.batch_size = batch_size
        self.train_size = train_size

    def prepare_data(self):
        # No state is stored here: Lightning calls this from the main process only.
        pass

    @staticmethod
    def _random_hue_pairs(n: int) -> list:
        """Return ``n`` ``(rgb, hue)`` pairs built from uniform random RGB values.

        ``rgb`` is one row of ``preprocess_data(torch.rand((n, 3)), skip=True)``;
        ``hue`` is ``rgb_to_hsv(rgb)[:, 0]`` as a float32 tensor.
        """
        rgb = torch.rand((int(n), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        # NOTE(review): `[:, 0]` requires each row to be at least 2-D after
        # preprocess_data(..., skip=True) (presumably (k, 3), making column 0
        # the hue channel) -- confirm against utils.preprocess_data.
        return [
            (row, torch.tensor(rgb_to_hsv(row)[:, 0], dtype=torch.float32))
            for row in rgb
        ]

    def setup(self, stage: str):
        """Build the datasets required for ``stage``.

        Args:
            stage: Lightning stage name: ``"fit"``, ``"validate"``, ``"test"``
                or ``"predict"``.
        """
        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        self.xkcd_colors = [
            (rgb, name.replace("xkcd:", ""))
            for rgb, name in zip(rgb_tensor, xkcd_color_names)
        ]

        # Random samples are redrawn on every setup() call without a dedicated
        # generator, so they only repeat across calls if the global torch seed
        # is fixed. Train is drawn before val to keep the original RNG order.
        if self.train_size > 0:
            # Sized by train_size (previously, and incorrectly, by val_size).
            self.train_colors = self._random_hue_pairs(self.train_size)
        self.random_colors = self._random_hue_pairs(self.val_size)

        # Assign datasets for use in the dataloaders.
        if stage == "fit":
            self.color_train = (
                self.train_colors if self.train_size > 0 else self.xkcd_colors
            )
        # trainer.validate() calls setup("validate"); val_dataloader needs
        # color_val in that case too, not only during "fit".
        if stage in ("fit", "validate"):
            self.color_val = self.random_colors
        if stage == "test":
            self.color_test = self.random_colors
        if stage == "predict":  # for visualizing
            self.color_predict = self.xkcd_colors

    def train_dataloader(self):
        """Return a DataLoader over the training set chosen in ``setup("fit")``."""
        return DataLoader(self.color_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a DataLoader over the random validation samples."""
        return DataLoader(self.color_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a DataLoader over the random test samples."""
        return DataLoader(self.color_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return a DataLoader over the XKCD colors (used for visualization)."""
        return DataLoader(self.color_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished; nothing to release here.
        pass
|
|
|