79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
import torch
|
|
|
|
from utils import PURE_RGB
|
|
|
|
# def smoothness_loss(outputs):
|
|
# # Sort outputs for smoothness calculation
|
|
# sorted_outputs, _ = torch.sort(outputs, dim=0)
|
|
# first_elements = sorted_outputs[:2]
|
|
|
|
# # Append the first two elements to the end of sorted_outputs
|
|
# extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
|
|
|
|
# # Calculate smoothness in the sorted outputs
|
|
# first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
|
|
# second_derivative = torch.diff(first_derivative, n=1, dim=0)
|
|
# smoothness_loss = torch.mean(torch.abs(second_derivative))
|
|
# return smoothness_loss
|
|
|
|
|
|
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
    """Penalize mismatch between pairwise RGB distances and transformed-space distances.

    Encourages the model to keep relative distances from the RGB space in the
    transformed space. Distances are computed between every row of
    ``inputs``/``outputs`` and every row of ``target_inputs``/``target_outputs``.
    Only the upper triangle (``torch.triu``, diagonal included) of each distance
    matrix is kept. The mean is then taken over the full matrix, so the zeroed
    lower-triangle entries are included in the average.

    Args:
        inputs: RGB colors, indexed as ``inputs[:, None, :]`` (2-D, ``(N, C)``;
            presumably C == 3 with values in [0, 1] -- TODO confirm).
        outputs: Transformed values for ``inputs``, shape ``(N, D)`` or ``(N,)``.
            A 1-D tensor is treated as ``(N, 1)``.
        target_inputs: Optional second set of colors ``(M, C)`` to compare
            against. Defaults to ``inputs``.
        target_outputs: Transformed values for ``target_inputs``, shape
            ``(M, D)`` or ``(M,)``. Required when ``target_inputs`` is given.
            Defaults to ``outputs``.

    Returns:
        0-dim tensor: mean absolute difference between the normalized RGB
        distance matrix and the transformed-space distance matrix.

    Raises:
        AssertionError: if ``target_inputs`` is given without ``target_outputs``.
    """
    if target_inputs is None:
        target_inputs = inputs
    else:
        assert target_outputs is not None, "target_outputs is required with target_inputs"
    if target_outputs is None:
        target_outputs = outputs

    # For a 1-D output vector, the norm over dim=-1 below would collapse the
    # (N, M) pairwise matrix into a vector, and torch.triu would then fail.
    # View such inputs as (N, 1) so every entry is a 1-element vector.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)
    if target_outputs.dim() == 1:
        target_outputs = target_outputs.unsqueeze(-1)

    # Assuming RGB values lie in [0, 1], the largest Euclidean distance in the
    # cube is sqrt(3) (black to white). Dividing by it scales distances to [0, 1].
    max_rgb_distance = 3.0 ** 0.5
    rgb_norm = (
        torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
        / max_rgb_distance
    )
    # Taking the result modulo 1 maps the maximal distance (black <-> white)
    # to 0, so those two colors are treated as coincident.
    rgb_norm = rgb_norm % 1

    # Transformed-space distance, wrapped modulo 1 to account for circularity.
    # NOTE(review): ``(a - b) % 1`` is a *directed* wrap-around distance, not
    # the symmetric circular distance ``min(d, 1 - d)``. Combined with triu,
    # the loss depends on the row order. Confirm this is intended.
    transformed_norm = torch.triu(
        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
    )

    diff = torch.abs(rgb_norm - transformed_norm)
    return torch.mean(diff)
|
|
|
|
|
|
def separation_loss(red, green, blue):
    """Pull the outputs for pure R, G and B toward evenly spaced anchors.

    The anchors are 0 for red, 1/3 for green and 2/3 for blue. Each term is the
    absolute offset from its anchor. The green and blue offsets are divided by
    2/3 (the largest offset possible inside [0, 1] for those anchors). The red
    offset is left unscaled. Offsets are linear, not wrapped modulo 1.

    Args:
        red: Model output for pure red.
        green: Model output for pure green.
        blue: Model output for pure blue.

    Returns:
        Elementwise sum of the three (scaled) offsets.
    """
    spacing = 2 / 3
    offset_r = torch.abs(0 - red)
    offset_g = torch.abs(1 / 3 - green)
    offset_b = torch.abs(2 / 3 - blue)
    return offset_r + offset_g / spacing + offset_b / spacing
|
|
|
|
|
|
def calculate_separation_loss(model):
    """Evaluate ``model`` on ``PURE_RGB`` and score the result with ``separation_loss``.

    ``PURE_RGB`` (imported from ``utils``) is moved to ``model.device`` before the
    forward pass, so ``model`` must expose a ``device`` attribute. Rows 0, 1 and 2
    of the output are passed as red, green and blue respectively (presumably
    ``PURE_RGB`` lists the pure colors in R, G, B order -- confirm in utils).
    """
    rgb_batch = PURE_RGB.to(model.device)
    predicted = model(rgb_batch)
    return separation_loss(red=predicted[0], green=predicted[1], blue=predicted[2])
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: map pure R, G, B to their ideal positions 0, 1/3, 2/3 and
    # print both losses. Black and white rows are defined but not used below.
    test_input = torch.tensor(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]], dtype=torch.float32
    )
    test_output = torch.tensor([[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32)

    primaries_in = test_input[:3]
    primaries_out = test_output[:3]
    print(preservation_loss(primaries_in, primaries_out))

    anchors = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
    r, g, b = anchors[0], anchors[1], anchors[2]
    print(separation_loss(red=r, green=g, blue=b))
|