Michael Pilosov
10 months ago
6 changed files with 109 additions and 75 deletions
@ -1,59 +1,83 @@ |
|||
import torch |
|||
|
|||
# def weighted_loss(inputs, outputs, alpha): |
|||
# # Calculate RGB Norm (Perceptual Difference) |
|||
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1) |
|||
from utils import PURE_RGB |
|||
|
|||
# # Calculate 1D Space Norm |
|||
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1) |
|||
# def smoothness_loss(outputs): |
|||
# # Sort outputs for smoothness calculation |
|||
# sorted_outputs, _ = torch.sort(outputs, dim=0) |
|||
# first_elements = sorted_outputs[:2] |
|||
|
|||
# # Weighted Loss |
|||
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm |
|||
# return torch.mean(loss) |
|||
# # Concatenate the first element at the end of the sorted_outputs |
|||
# extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0) |
|||
|
|||
|
|||
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold): |
|||
# # Calculate RGB Norm |
|||
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1) |
|||
|
|||
# # Calculate 1D Space Norm |
|||
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1) |
|||
|
|||
# # Identify Distinct Colors (based on a threshold in RGB space) |
|||
# distinct_colors = rgb_norm > distinct_threshold |
|||
|
|||
# # Penalty for Distinct Colors being too close in the transformed space |
|||
# # Here we do not take the mean yet, to avoid double averaging |
|||
# distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors |
|||
|
|||
# # Combined Loss |
|||
# # The mean is taken here, once, after all components are combined |
|||
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty |
|||
# return torch.mean(loss) |
|||
# # Calculate smoothness in the sorted outputs |
|||
# first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0) |
|||
# second_derivative = torch.diff(first_derivative, n=1, dim=0) |
|||
# smoothness_loss = torch.mean(torch.abs(second_derivative)) |
|||
# return smoothness_loss |
|||
|
|||
|
|||
def preservation_loss(inputs, outputs):
    """Mean absolute mismatch between pairwise distances before and after the mapping.

    Args:
        inputs: 2-D tensor, one row per sample (the original comments call
            these RGB values).
        outputs: the transformed value for each sample, shaped ``(n,)`` or
            ``(n, d)``.

    Returns:
        Scalar tensor: the mean over all sample pairs ``(i, j)`` of
        ``| ||inputs_i - inputs_j|| - ||outputs_i - outputs_j|| |``.
    """
    # A 1-D ``outputs`` would give an (n, n) difference whose norm over the
    # last axis collapses to (n,) and then silently broadcasts against the
    # (n, n) input distances, so promote it to a single-column matrix first.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances in the input (RGB) space, shape (n, n).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise Euclidean distances in the transformed space, shape (n, n).
    transformed_norm = torch.norm(
        outputs[:, None, :] - outputs[None, :, :], dim=-1
    )

    # Distance preservation: penalize any pair whose separation changed.
    return torch.mean(torch.abs(rgb_norm - transformed_norm))
|||
|
|||
|
|||
def smoothness_loss(outputs):
    """Mean absolute second difference of the sorted outputs, with wrap-around.

    The outputs are sorted along ``dim=0``, the two smallest entries are
    appended after the largest, and the mean absolute second-order difference
    of that extended sequence is returned.

    Args:
        outputs: tensor whose first dimension indexes samples.

    Returns:
        Scalar tensor. With fewer than two samples the second difference is
        empty, so a zero that stays attached to ``outputs`` is returned
        instead of the NaN that ``torch.mean`` yields for an empty tensor.
    """
    if outputs.shape[0] < 2:
        # Multiplying the sum by zero keeps the result connected to
        # ``outputs`` while contributing nothing to the total loss.
        return outputs.sum() * 0.0

    ordered, _ = torch.sort(outputs, dim=0)

    # Append the first TWO sorted entries so the second difference also
    # covers the step from the end of the sequence back to its start.
    wrapped = torch.cat((ordered, ordered[:2]), dim=0)

    # n=2 applies the first-order difference twice.
    second_difference = torch.diff(wrapped, n=2, dim=0)
    return torch.mean(torch.abs(second_difference))
|||
# Calculate RGB Norm |
|||
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1] |
|||
rgb_norm = ( |
|||
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)) |
|||
/ max_rgb_distance |
|||
) |
|||
rgb_norm = ( |
|||
rgb_norm % 1 |
|||
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space |
|||
# print(rgb_norm) |
|||
|
|||
# Calculate 1D Space Norm (modulo 1 to account for circularity) |
|||
transformed_norm = torch.triu( |
|||
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1) |
|||
) |
|||
|
|||
diff = torch.abs(rgb_norm - transformed_norm) |
|||
# print(diff) |
|||
|
|||
return torch.mean(diff) |
|||
|
|||
|
|||
def _circular_distance(a, b):
    """Shortest distance between two positions on a circle of circumference 1."""
    gap = torch.abs(a - b)
    return torch.minimum(gap, 1 - gap)


def separation_loss(red, green, blue):
    """Penalize unequal spacing of three positions on the unit circle.

    Each argument is reduced modulo 1, the three pairwise wrap-around
    distances are computed, and the absolute differences between those
    distances are summed.

    Args:
        red: tensor position of the red output.
        green: tensor position of the green output.
        blue: tensor position of the blue output.

    Returns:
        Tensor that is zero when the three pairwise circular distances are
        equal and grows as they diverge.
    """
    red, green, blue = red % 1, green % 1, blue % 1

    # The wrap-around distance must be symmetric: the previous form
    # min(|a - b|, |1 + a - b|) only wrapped when a <= b, so the result
    # depended on the order of the arguments (0.9 vs 0.1 gave 0.8, not 0.2).
    red_green_distance = _circular_distance(red, green)
    red_blue_distance = _circular_distance(red, blue)
    green_blue_distance = _circular_distance(green, blue)

    # We want these distances to be equal to one another.
    return (
        torch.abs(red_green_distance - red_blue_distance)
        + torch.abs(red_green_distance - green_blue_distance)
        + torch.abs(red_blue_distance - green_blue_distance)
    )
|||
|
|||
|
|||
def calculate_separation_loss(model):
    """Evaluate ``separation_loss`` on the model's outputs for ``PURE_RGB``.

    The first three entries of ``model(PURE_RGB)`` are passed on as the red,
    green and blue positions, in that order.
    """
    predictions = model(PURE_RGB)
    return separation_loss(
        red=predictions[0],
        green=predictions[1],
        blue=predictions[2],
    )
|||
|
|||
|
|||
if __name__ == "__main__":
    # Smoke test for the losses: the three pure primaries mapped to evenly
    # spaced positions 0, 1/3 and 2/3.
    sample_colors = torch.tensor(
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 0, 0],
            [1, 1, 1],
        ],
        dtype=torch.float32,
    )
    sample_positions = torch.tensor(
        [[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32
    )

    # Only the first three rows (pure R, G, B) take part in the check.
    primaries = sample_colors[:3]
    primary_positions = sample_positions[:3]
    print(preservation_loss(primaries, primary_positions))

    thirds = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
    r_pos, g_pos, b_pos = thirds[0], thirds[1], thirds[2]
    print(separation_loss(red=r_pos, green=g_pos, blue=b_pos))
|||
|
@ -0,0 +1,24 @@ |
|||
import torch |
|||
|
|||
|
|||
def preprocess_data(data):
    """Append each row's scaled argmin and argmax positions as extra columns.

    Args:
        data: 2-D tensor; the reductions run along ``dim=1``. The original
            comment assumes shape ``[n_samples, 3]``.

    Returns:
        Tensor with two more columns than ``data``: the original columns,
        then the argmin index, then the argmax index of each row. Indices are
        divided by ``n_features - 1`` so they lie in ``[0, 1]``.
    """
    # The largest possible index is n_features - 1. With a single column that
    # is 0 and the only index is 0, so dividing would give 0 / 0 = NaN; clamp
    # the divisor to 1 to keep those columns at 0.
    index_scale = max(data.shape[1] - 1, 1)

    # Column position of the smallest / largest entry in each row.
    argmin_values = torch.argmin(data, dim=1, keepdim=True).float() / index_scale
    argmax_values = torch.argmax(data, dim=1, keepdim=True).float() / index_scale

    # Concatenate the argmin and argmax values to the original data.
    return torch.cat((data, argmin_values, argmax_values), dim=1)
|||
|
|||
|
|||
# Pure red, green and blue, one primary per row (the 3x3 identity matrix),
# run through preprocess_data.
PURE_RGB = preprocess_data(torch.eye(3, dtype=torch.float32))
Loading…
Reference in new issue