colors/utils.py

43 lines
1.4 KiB
Python
Raw Normal View History

import matplotlib.colors as mcolors
2024-01-14 03:11:49 +00:00
import torch
2024-01-16 04:37:22 +00:00
def preprocess_data(data, skip: bool = False):
    """Append normalized argmin/argmax column indices as extra features.

    Args:
        data: 2-D tensor of shape ``[n_samples, n_features]``. Within this
            module it is called with RGB rows (``n_features == 3``).
        skip: If True, return ``data`` unchanged.

    Returns:
        ``data`` itself when ``skip`` is True. Otherwise a new tensor of shape
        ``[n_samples, n_features + 2]``: the original columns followed by the
        column index of each row's minimum and of its maximum, each divided by
        ``n_features - 1`` so they lie in ``[0, 1]``. For single-column input
        both indices are 0 and stay 0 (instead of the NaN that ``0 / 0``
        would produce).
    """
    if skip:
        return data

    n_features = data.shape[1]
    # Divide by the largest possible column index to map indices into [0, 1].
    # With a single column that index is 0, so fall back to 1 to avoid 0/0.
    scale = max(n_features - 1, 1)

    argmin_values = torch.argmin(data, dim=1, keepdim=True).float() / scale
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float() / scale
    return torch.cat((data, argmin_values, argmax_values), dim=1)
def extract_colors():
    """Return the xkcd palette as an RGB tensor together with its colour names.

    Returns:
        A ``(rgb_tensor, names)`` pair. ``rgb_tensor`` is a float32 tensor with
        one ``(r, g, b)`` row per entry of ``matplotlib.colors.XKCD_COLORS``.
        ``names`` is the list of that mapping's keys, in the same row order.
    """
    palette = mcolors.XKCD_COLORS
    names = list(palette)
    # to_rgb converts each colour spec into an (r, g, b) tuple of floats.
    triples = [mcolors.to_rgb(palette[name]) for name in names]
    return torch.tensor(triples, dtype=torch.float32), names
2024-01-25 06:12:27 +00:00
# Six saturated reference colours: the three primaries followed by the three
# secondaries. They are passed through preprocess_data, so each row also gets
# the appended argmin/argmax feature columns.
PURE_RGB = preprocess_data(
    torch.tensor(
        [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 0.0, 1.0],  # magenta
            [1.0, 1.0, 0.0],  # yellow
            [0.0, 1.0, 1.0],  # cyan
        ]
    )
)
2024-01-16 05:19:54 +00:00
# One value per PURE_RGB row, as a fraction of a full turn. These match the
# standard HSV hues of those rows, in order: red 0, green 1/3, blue 2/3,
# magenta 5/6, yellow 1/6, cyan 1/2. The result is a column tensor of shape [6, 1].
PURE_HSV = torch.tensor(
    [0.0, 1 / 3, 2 / 3, 5 / 6, 1 / 6, 0.5], dtype=torch.float32
).unsqueeze(1)