You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

83 lines
3.0 KiB

import torch
from utils import PURE_RGB
# def smoothness_loss(outputs):
# # Sort outputs for smoothness calculation
# sorted_outputs, _ = torch.sort(outputs, dim=0)
# first_elements = sorted_outputs[:2]
# # Concatenate the first element at the end of the sorted_outputs
# extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
# # Calculate smoothness in the sorted outputs
# first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
# second_derivative = torch.diff(first_derivative, n=1, dim=0)
# smoothness_loss = torch.mean(torch.abs(second_derivative))
# return smoothness_loss
def preservation_loss(inputs, outputs):
    """Penalize mismatch between pairwise RGB distances and 1D output distances.

    Encourages the model to keep the relative distances between colours in
    RGB space when they are mapped into the (wrapped, modulo-1) output space.

    Args:
        inputs: Colour tensor indexed as ``inputs[i, channel]``; presumably
            ``(N, 3)`` RGB values in ``[0, 1]`` (TODO confirm with callers).
        outputs: Model outputs for the same colours, one row per colour
            (the ``__main__`` demo uses shape ``(N, 1)``).

    Returns:
        Scalar tensor: the mean absolute difference between the two
        upper-triangular distance matrices. The mean is taken over every
        matrix entry, so the zero diagonal and the zeroed lower triangle
        are included in the average.
    """
    # For RGB values in [0, 1] the longest distance is the cube diagonal,
    # sqrt(3); dividing by it scales RGB distances into [0, 1].
    cube_diagonal = torch.tensor(3.0).sqrt()

    rgb_offsets = inputs[:, None, :] - inputs[None, :, :]
    rgb_dist = torch.triu(torch.norm(rgb_offsets, dim=-1)) / cube_diagonal
    # Wrap at 1 so the two farthest colours, (0, 0, 0) and (1, 1, 1), end up
    # at distance 0, i.e. they are "connected" like ends of a circle.
    rgb_dist = rgb_dist % 1

    # Wrap output coordinate differences modulo 1 before taking the norm.
    # NOTE(review): (o_i - o_j) % 1 is a directed wrap, not the symmetric
    # circular distance min(d, 1 - d); combined with triu the result depends
    # on the row order of the inputs. Confirm this is intended.
    out_offsets = (outputs[:, None] - outputs[None, :]) % 1
    out_dist = torch.triu(torch.norm(out_offsets, dim=-1))

    return (rgb_dist - out_dist).abs().mean()
def _circular_distance(a, b):
    """Return the shortest distance between positions a and b on a circle of circumference 1."""
    # Python's % on tensors is torch.remainder, whose result takes the sign of
    # the divisor, so d lies in [0, 1) whatever the sign of a - b. Going the
    # other way round the circle costs 1 - d; the shorter of the two wins.
    d = (a - b) % 1
    return torch.minimum(d, 1 - d)


def separation_loss(red, green, blue):
    """Penalize unequal circular spacing between the R, G, B embeddings.

    Encourages the model to keep pure red, green and blue equally far apart
    in the circular 1D output space (positions are taken modulo 1). Each
    pairwise distance is the symmetric shortest arc, so the result does not
    depend on which colour sits "before" another on the circle.

    Args:
        red, green, blue: Broadcast-compatible tensors holding the model
            output for pure red, green and blue respectively.

    Returns:
        Tensor of the broadcast shape of the inputs: the sum of absolute
        pairwise differences between the three circular distances. It is
        zero exactly when all three distances are equal.
    """
    red_green_distance = _circular_distance(red, green)
    red_blue_distance = _circular_distance(red, blue)
    green_blue_distance = _circular_distance(green, blue)
    # We want these distances to be equal to one another.
    return (
        torch.abs(red_green_distance - red_blue_distance)
        + torch.abs(red_green_distance - green_blue_distance)
        + torch.abs(red_blue_distance - green_blue_distance)
    )
def calculate_separation_loss(model):
    """Run ``model`` on ``PURE_RGB`` and score how evenly it separates the primaries.

    Thin wrapper around ``separation_loss``. Rows 0, 1 and 2 of
    ``model(PURE_RGB)`` are passed as the red, green and blue outputs.
    """
    # NOTE(review): assumes PURE_RGB (imported from utils) lists red, green,
    # blue in that order as its first three rows — confirm against utils.
    predictions = model(PURE_RGB)
    red, green, blue = (predictions[row] for row in range(3))
    return separation_loss(red, green, blue)
if __name__ == "__main__":
    # Smoke test for preservation_loss: pure red, green and blue mapped to
    # evenly spaced 1D positions. Black and white are listed too, but only
    # the first three rows are currently used.
    samples = torch.tensor(
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 0, 0],
            [1, 1, 1],
        ],
        dtype=torch.float32,
    )
    positions = torch.tensor(
        [[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32
    )
    n_primaries = 3
    print(preservation_loss(samples[:n_primaries], positions[:n_primaries]))

    # Smoke test for separation_loss: evenly spaced primaries.
    hues = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
    first, second, third = hues
    print(separation_loss(red=first, green=second, blue=third))