"""Data loaders built from the XKCD color table shipped with matplotlib.

Every RGB row is augmented by ``preprocess_data`` with two extra features
(the scaled index of its smallest and of its largest channel), so samples
produced by this module have five values per color.
"""
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def extract_colors() -> Tuple[torch.Tensor, List[str]]:
    """Return the XKCD colors as an RGB tensor together with their names.

    Returns:
        A pair ``(rgb_tensor, names)``. ``rgb_tensor`` is a float32 tensor
        with one row per color holding the triple produced by
        ``matplotlib.colors.to_rgb``; ``names`` are the keys of
        ``matplotlib.colors.XKCD_COLORS`` in the same order as the rows.
    """
    # Imported here because this is the only function that needs matplotlib;
    # the torch-only helpers in this module stay usable without it.
    import matplotlib.colors as mcolors

    xkcd_colors = mcolors.XKCD_COLORS
    rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]
    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
    return rgb_tensor, list(xkcd_colors)


def create_dataloader(N: int = 50, **kwargs) -> DataLoader:
    """Build a DataLoader over the preprocessed XKCD colors.

    Each sample is ``(features, 0.0)``: the label is a dummy zero so the
    dataset has the usual (input, target) shape.

    Args:
        N: Not used by this function; kept so existing callers that pass it
            keep working.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``).

    Returns:
        A DataLoader yielding batches of ``(features, dummy_labels)``.
    """
    rgb_tensor, _ = extract_colors()
    features = preprocess_data(rgb_tensor)
    dummy_labels = torch.zeros(len(features))
    return DataLoader(TensorDataset(features, dummy_labels), **kwargs)


def create_gray_supplement(N: int = 50) -> List[Tuple[torch.Tensor, str]]:
    """Create ``N`` evenly spaced gray samples from black to white.

    Args:
        N: Number of gray levels, generated with ``torch.linspace(0, 1, N)``
            (both end points included).

    Returns:
        A list of ``(features, name)`` pairs. ``features`` is the
        preprocessed gray row (R = G = B = level) and ``name`` is
        ``"gray"`` followed by the level formatted to four decimals.
    """
    linear_space = torch.linspace(0, 1, N)
    gray_tensor = preprocess_data(linear_space.unsqueeze(1).repeat(1, 3))
    # Name each sample after the level it actually contains. linspace steps
    # by 1/(N-1), so deriving the name from i/N would mislabel every sample
    # after the first (e.g. pure white would be called "gray0.9800").
    return [
        (row, f"gray{level:.4f}")
        for row, level in zip(gray_tensor, linear_space.tolist())
    ]


def create_named_dataloader(N: int = 50, **kwargs) -> DataLoader:
    """Build a DataLoader whose samples carry color names instead of labels.

    The dataset holds every XKCD color (name with the ``"xkcd:"`` prefix
    removed) followed by the ``N`` gray samples from
    ``create_gray_supplement``.

    Args:
        N: Number of supplementary gray samples to append.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        A DataLoader yielding batches of ``(features, names)``.
    """
    rgb_tensor, xkcd_color_names = extract_colors()
    features = preprocess_data(rgb_tensor)
    dataset_with_names = [
        (row, name.replace("xkcd:", ""))
        for row, name in zip(features, xkcd_color_names)
    ]
    dataset_with_names += create_gray_supplement(N)
    return DataLoader(dataset_with_names, **kwargs)


def preprocess_data(data: torch.Tensor) -> torch.Tensor:
    """Append the scaled argmin and argmax of each row as extra features.

    Args:
        data: 2-D tensor of shape ``[n_samples, n_features]`` (three
            features for RGB input).

    Returns:
        Tensor of shape ``[n_samples, n_features + 2]``: the original
        columns, then the argmin index and the argmax index of each row,
        both divided by ``n_features``.
    """
    n_features = data.shape[1]
    # Dividing by the feature count keeps both index columns in [0, 1).
    argmin_values = torch.argmin(data, dim=1, keepdim=True).float() / n_features
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float() / n_features
    return torch.cat((data, argmin_values, argmax_values), dim=1)


def main() -> None:
    """Print one batch from each loader as a quick smoke check."""
    batch_size = 4
    train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
    train_dataloader_with_names = create_named_dataloader(
        batch_size=batch_size, shuffle=True
    )

    # One batch of features; the dummy labels are discarded.
    sample_rgb_values, _ = next(iter(train_dataloader))
    print(sample_rgb_values)

    # One batch of features together with their color names.
    sample_rgb_values_with_names, sample_color_names = next(
        iter(train_dataloader_with_names)
    )
    print(sample_rgb_values_with_names, sample_color_names)


if __name__ == "__main__":
    main()