You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
2.0 KiB

11 months ago
import matplotlib.colors as mcolors
import torch
from torch.utils.data import DataLoader, TensorDataset
def extract_colors():
    """Return matplotlib's xkcd palette as an RGB tensor plus the color names.

    Each hex value in ``matplotlib.colors.XKCD_COLORS`` is converted with
    ``mcolors.to_rgb`` (an RGB triple of floats). The triples are stacked into a
    float32 tensor with one row per color.

    Returns:
        tuple: ``(rgb_tensor, names)``. ``rgb_tensor`` is a float32 tensor of
        shape (N, 3). ``names`` is a list of the ``XKCD_COLORS`` keys, in the
        same order as the tensor rows.
    """
    names = []
    triples = []
    # A single pass over items() keeps every name aligned with its RGB row.
    for name, hex_code in mcolors.XKCD_COLORS.items():
        names.append(name)
        triples.append(mcolors.to_rgb(hex_code))
    return torch.tensor(triples, dtype=torch.float32), names
def create_dataloader(**kwargs):
    """Build a DataLoader over the xkcd RGB values paired with dummy labels.

    Each sample is an ``(rgb, label)`` pair, and every label is 0.0. The labels
    presumably exist only so samples have an (input, target) shape; confirm
    against callers.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``).

    Returns:
        DataLoader: Iterates over a ``TensorDataset`` of (RGB row, zero label).
    """
    colors, _names = extract_colors()
    placeholder_labels = torch.zeros(colors.shape[0])
    return DataLoader(TensorDataset(colors, placeholder_labels), **kwargs)
def create_named_dataloader(**kwargs):
    """Build a DataLoader that yields xkcd RGB values together with their names.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``).

    Returns:
        DataLoader: Iterates over a list of ``(rgb_row, color_name)`` tuples.
        With the default collate function, each batch unpacks into a batched
        RGB tensor and the corresponding color names.
    """
    rgb_tensor, xkcd_color_names = extract_colors()
    # Pair each RGB row with its name directly. Iterating a tensor yields its
    # rows along dim 0, so this matches indexing rgb_tensor[i] by position.
    dataset_with_names = list(zip(rgb_tensor, xkcd_color_names))
    return DataLoader(dataset_with_names, **kwargs)
if __name__ == "__main__":
    # Demo: build both loaders with identical settings and print one batch of each.
    loader_kwargs = {"batch_size": 4, "shuffle": True}
    train_dataloader = create_dataloader(**loader_kwargs)
    train_dataloader_with_names = create_named_dataloader(**loader_kwargs)

    # Plain loader: discard the dummy labels and show the RGB batch.
    rgb_batch, _labels = next(iter(train_dataloader))
    print(rgb_batch)

    # Named loader: each batch carries RGB values plus their color names.
    named_rgb_batch, name_batch = next(iter(train_dataloader_with_names))
    print(named_rgb_batch, name_batch)