colors/dataloader.py

60 lines
2.2 KiB
Python
Raw Normal View History

2023-12-30 04:37:06 +00:00
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import extract_colors, preprocess_data
2023-12-30 04:37:06 +00:00
2024-01-16 04:37:22 +00:00
def create_dataloader(N: int = 100_000_000, skip: bool = True, **kwargs):
    """Build a DataLoader over ``N`` uniformly random RGB triples.

    Args:
        N: Number of random colours to draw. Float-valued counts such as
            ``1e6`` are still accepted and are truncated with ``int``.
        skip: Forwarded unchanged to ``preprocess_data``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (for example ``batch_size`` or ``shuffle``).

    Returns:
        A ``DataLoader`` wrapping a ``TensorDataset`` of the preprocessed
        colours paired with an all-zero placeholder label tensor.
    """
    # Callers pass counts in scientific notation (1e6), so coerce to int
    # before using the value as a tensor dimension.
    num_samples = int(N)
    rgb_tensor = torch.rand((num_samples, 3), dtype=torch.float32)
    rgb_tensor = preprocess_data(rgb_tensor, skip=skip)

    # The zero tensor is only a dummy second field; it is sized from the
    # preprocessed tensor so both fields always have matching length.
    placeholder_labels = torch.zeros(len(rgb_tensor))
    dataset = TensorDataset(rgb_tensor, placeholder_labels)
    return DataLoader(dataset, **kwargs)
2024-01-16 04:37:22 +00:00
def create_gray_supplement(N: int = 50, skip: bool = True):
    """Return ``N`` labelled gray colours running from black to white.

    Args:
        N: Number of evenly spaced gray levels to generate on ``[0, 1]``.
        skip: Forwarded unchanged to ``preprocess_data``.

    Returns:
        A list of ``(color, name)`` tuples, where ``color`` is one row of
        the preprocessed gray tensor and ``name`` is ``"gray"`` followed by
        ``i / N`` formatted with four decimals.
    """
    # One column of levels, copied into the three RGB channels.
    levels = torch.linspace(0, 1, N)
    grays = preprocess_data(levels[:, None].repeat(1, 3), skip=skip)

    # NOTE(review): the label uses i / N, while linspace(0, 1, N) steps by
    # 1 / (N - 1); the label therefore differs from the generated level
    # (the last row, level 1.0, is labelled (N - 1) / N). Confirm whether
    # this is intended before relying on the names as values.
    return [(row, f"gray{i/N:2.4f}") for i, row in enumerate(grays)]
2024-01-16 04:37:22 +00:00
def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
    """Build a DataLoader of named colours, optionally padded with grays.

    Args:
        N: Number of gray entries from ``create_gray_supplement`` to
            append. Nothing is appended unless ``N > 0``.
        skip: Forwarded unchanged to ``preprocess_data`` (and to
            ``create_gray_supplement`` when grays are added).
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        A ``DataLoader`` over a list of ``(color, name)`` tuples.
    """
    colors, raw_names = extract_colors()
    colors = preprocess_data(colors, skip=skip)

    # Pair every colour row with its name, dropping the "xkcd:" marker.
    named_samples = [
        (color, raw_names[idx].replace("xkcd:", ""))
        for idx, color in enumerate(colors)
    ]

    if N > 0:
        named_samples.extend(create_gray_supplement(N, skip=skip))

    return DataLoader(named_samples, **kwargs)
if __name__ == "__main__":
    # Manual smoke test: build both loaders and print one batch from each.
    batch_size = 4
    loader_options = {"batch_size": batch_size, "shuffle": True}

    train_dataloader = create_dataloader(N=1e6, **loader_options)
    print(len(train_dataloader.dataset))

    train_dataloader_with_names = create_named_dataloader(**loader_options)

    # First batch of the random loader: RGB values plus dummy labels,
    # which are discarded.
    sample_rgb_values, _ = next(iter(train_dataloader))
    print(sample_rgb_values)

    # First batch of the named loader: RGB values plus colour names.
    named_rgb_values, named_color_names = next(iter(train_dataloader_with_names))
    print(named_rgb_values, named_color_names)