You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.6 KiB
78 lines
2.6 KiB
import matplotlib.colors as mcolors
|
|
import torch
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
def extract_colors():
    """Return the XKCD color table as a float tensor plus its color names.

    Returns:
        A ``(rgb_tensor, names)`` pair. ``names`` is the list of keys of
        ``mcolors.XKCD_COLORS`` in the mapping's own order, and row ``i`` of
        ``rgb_tensor`` (dtype ``torch.float32``) is the RGB triple that
        ``mcolors.to_rgb`` produces for the color stored under ``names[i]``.
    """
    palette = mcolors.XKCD_COLORS

    # Walk the keys once and look each value up, so the name list and the
    # tensor rows are built from the same ordering.
    names = list(palette)
    triples = [mcolors.to_rgb(palette[name]) for name in names]

    return torch.tensor(triples, dtype=torch.float32), names
|
|
|
|
|
|
def create_dataloader(**kwargs):
    """Build a ``DataLoader`` over the preprocessed XKCD color features.

    Each dataset item is a ``(features, label)`` pair. The label is a
    placeholder zero, because ``TensorDataset`` is given an all-zero tensor
    with one entry per color.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (for example ``batch_size`` or ``shuffle``).

    Returns:
        The configured ``DataLoader``.
    """
    colors, _names = extract_colors()
    features = preprocess_data(colors)

    # One dummy label per sample; the color names are not used here.
    placeholder_labels = torch.zeros(features.size(0))

    return DataLoader(TensorDataset(features, placeholder_labels), **kwargs)
|
|
|
|
|
|
def create_named_dataloader(**kwargs):
    """Build a ``DataLoader`` whose items pair color features with names.

    The dataset is a plain list of ``(feature_row, color_name)`` tuples, one
    per color returned by ``extract_colors``, with the feature rows taken
    from the output of ``preprocess_data``.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (for example ``batch_size`` or ``shuffle``).

    Returns:
        The configured ``DataLoader``.
    """
    colors, names = extract_colors()
    features = preprocess_data(colors)

    # Iterating a 2-D tensor yields its rows, so zip lines each row up with
    # the name at the same position.
    named_samples = list(zip(features, names))

    return DataLoader(named_samples, **kwargs)
|
|
|
|
|
|
def preprocess_data(data):
    """Append scaled per-row argmin and argmax columns to ``data``.

    Args:
        data: A 2-D tensor; the original comment describes it as
            ``[n_samples, 3]``. Reductions run along ``dim=1``.

    Returns:
        A tensor with two extra columns after the original ones: the index of
        each row's minimum and the index of each row's maximum, both divided
        by the number of columns in ``data``.
    """
    n_features = data.shape[1]

    # Indices come back as integers; cast to float so they can be scaled and
    # concatenated alongside the feature values.
    lowest_idx = data.argmin(dim=1, keepdim=True).float() / n_features
    highest_idx = data.argmax(dim=1, keepdim=True).float() / n_features

    return torch.cat([data, lowest_idx, highest_idx], dim=1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build both loaders with the same settings and print the
    # first batch from each.
    batch_size = 4
    loader_options = {"batch_size": batch_size, "shuffle": True}

    train_dataloader = create_dataloader(**loader_options)
    train_dataloader_with_names = create_named_dataloader(**loader_options)

    # First batch of the plain loader: features plus dummy labels (ignored).
    first_features, _ = next(iter(train_dataloader))
    print(first_features)

    # First batch of the named loader: features plus their color names.
    named_features, batch_names = next(iter(train_dataloader_with_names))
    print(named_features, batch_names)
|
|
|