You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data(data, skip=True):
    """Optionally append normalized argmin/argmax columns to ``data``.

    Args:
        data: 2-D tensor of shape ``[n_samples, n_features]`` (callers in this
            file pass ``n_features == 3``).
        skip: When True (the default), ``data`` is returned unchanged.

    Returns:
        ``data`` itself when ``skip`` is True. Otherwise a new tensor of shape
        ``[n_samples, n_features + 2]``: the original columns followed by each
        row's argmin index and argmax index, both divided by the largest
        possible index ``n_features - 1`` so they lie in ``[0, 1]``.
    """
    if skip:
        return data

    # Row-wise index of the smallest / largest feature, kept 2-D (keepdim) so
    # they can be concatenated as extra columns; cast to float for scaling.
    argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float()

    # Normalize the indices by the largest possible index (n_features - 1).
    # With a single feature that index is 0 and both argmin/argmax are always
    # 0, so clamp the divisor to 1 to avoid 0/0 = NaN.
    scale = max(data.shape[1] - 1, 1)
    argmin_values /= scale
    argmax_values /= scale

    return torch.cat((data, argmin_values, argmax_values), dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
# 3x3 float32 identity: one one-hot row per channel (presumably pure red,
# green and blue, per the name). With preprocess_data's default skip=True the
# tensor is returned unchanged, so no argmin/argmax columns are appended.
PURE_RGB = preprocess_data(torch.eye(3, dtype=torch.float32))
|