You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.4 KiB
40 lines
1.4 KiB
import matplotlib.colors as mcolors
|
|
import torch
|
|
|
|
|
|
def preprocess_data(data, skip: bool = False):
    """Append the normalized positions of each row's minimum and maximum.

    Args:
        data: 2-D tensor of shape ``[n_samples, n_features]``.
        skip: If true, return ``data`` itself unchanged.

    Returns:
        A tensor of shape ``[n_samples, n_features + 2]``: the original
        columns followed by ``argmin / (n_features - 1)`` and
        ``argmax / (n_features - 1)`` of each row, i.e. the index of the
        row's smallest and largest entry mapped onto ``[0, 1]``. The two
        extra columns are computed as float32 before concatenation. When
        ``n_features == 1`` both extra columns are ``0`` (rather than NaN).
    """
    if skip:
        return data

    n_features = data.shape[1]
    # Divide by the largest possible index (not the feature count) so the
    # positions land in [0, 1]. A single-column input only has index 0;
    # clamping the divisor to 1 yields 0 instead of 0 / 0 = NaN.
    scale = max(n_features - 1, 1)

    # keepdim=True keeps each result as an [n_samples, 1] column for torch.cat.
    argmin_values = torch.argmin(data, dim=1, keepdim=True).float() / scale
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float() / scale

    return torch.cat((data, argmin_values, argmax_values), dim=1)
|
|
|
|
|
|
def extract_colors():
    """Return every xkcd colour as an RGB tensor together with its name.

    Returns:
        tuple: ``(rgb_tensor, names)`` where ``names`` is the list of keys
        of ``matplotlib.colors.XKCD_COLORS`` (in mapping order) and
        ``rgb_tensor`` is a float32 tensor holding, row for row, the
        ``mcolors.to_rgb`` triple of the colour with the same name.
    """
    palette = mcolors.XKCD_COLORS
    names = list(palette)
    # Look each colour up by name so rows and names share one ordering.
    rows = [mcolors.to_rgb(palette[name]) for name in names]
    return torch.tensor(rows, dtype=torch.float32), names
|
|
|
|
|
|
# The six pure colours as RGB rows, in the same order as PURE_HSV below:
# red, green, blue (the one-hot rows of the identity), then magenta,
# yellow, cyan (row i of the second term is primary i plus primary
# (i + 2) % 3). preprocess_data then appends each row's normalized
# argmin/argmax positions, giving a tensor of shape [6, 5].
# NOTE(review): previously only the three secondary colours were built
# (3 rows), which did not line up with the 6 hues in PURE_HSV.
PURE_RGB = preprocess_data(
    torch.cat(
        (torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]),
        dim=0,
    )
)

# HSV hue (as a fraction of the colour wheel) of each PURE_RGB row:
# red 0, green 1/3, blue 2/3, magenta 5/6, yellow 1/6, cyan 1/2.
# Shape [6, 1].
PURE_HSV = torch.tensor(
    [[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32
)
|
|
|