"""Colour helpers: xkcd palette extraction and argmin/argmax feature augmentation."""

import matplotlib.colors as mcolors
import torch


def preprocess_data(data: torch.Tensor, skip: bool = False) -> torch.Tensor:
    """Append normalised per-row argmin and argmax columns to ``data``.

    For every row, the column index of the smallest and of the largest value
    is computed, divided by the largest possible column index
    (``data.shape[1] - 1``) and appended as two extra columns.

    Args:
        data: 2-D tensor. The original note on this function assumes shape
            ``[n_samples, 3]``; the code itself only relies on there being a
            second dimension.
        skip: When true, ``data`` is returned as-is (same object, no copy).

    Returns:
        ``data`` itself if ``skip`` is true, otherwise a new tensor with two
        more columns than ``data``: normalised argmin, then normalised argmax.
    """
    if skip:
        return data

    # Scale indices by the largest column index (not the number of features).
    # Clamp to 1 so that single-column input yields 0.0 instead of 0/0 = NaN.
    max_index = max(data.shape[1] - 1, 1)

    argmin_values = torch.argmin(data, dim=1, keepdim=True).float() / max_index
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float() / max_index

    return torch.cat((data, argmin_values, argmax_values), dim=1)


def extract_colors():
    """Return the xkcd colour palette as an RGB tensor plus its colour names.

    Returns:
        A ``(rgb_tensor, names)`` pair. ``rgb_tensor`` is a float32 tensor
        with one RGB triple per entry of ``matplotlib.colors.XKCD_COLORS``;
        ``names`` is the list of that mapping's keys, in the same order as
        the rows of ``rgb_tensor``.
    """
    xkcd_colors = mcolors.XKCD_COLORS

    # Values and keys are read from the same dict, so row i of the tensor
    # corresponds to names[i].
    rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]
    xkcd_color_names = list(xkcd_colors.keys())

    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
    return rgb_tensor, xkcd_color_names


# Six reference colours, one per row, run through preprocess_data (so each row
# carries the two extra argmin/argmax columns). The first three rows are the
# identity matrix: (1,0,0), (0,1,0), (0,0,1). The last three add a
# column-rotated identity to it, giving (1,0,1), (1,1,0), (0,1,1).
PURE_RGB = preprocess_data(
    torch.cat(
        [torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]],
        dim=0,
    )
)

# One value per PURE_RGB row, in the same row order, shape [6, 1]. The values
# equal the hue (as a fraction of a full turn) of the corresponding RGB row.
PURE_HSV = torch.tensor(
    [[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32
)