|
|
@ -17,6 +17,42 @@ from utils import PURE_RGB |
|
|
|
# return smoothness_loss |
|
|
|
|
|
|
|
|
|
|
|
def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): |
|
|
|
# Distance Preservation Component (or scaled euclidean if given targets) |
|
|
|
# Encourages the model to keep relative distances from the RGB space in the transformed space |
|
|
|
if target_inputs is None: |
|
|
|
target_inputs = inputs |
|
|
|
else: |
|
|
|
assert target_outputs is not None |
|
|
|
if target_outputs is None: |
|
|
|
target_outputs = outputs |
|
|
|
|
|
|
|
# Calculate RGB Norm |
|
|
|
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1] |
|
|
|
# max_rgb_distance = 1 |
|
|
|
rgb_norm = ( |
|
|
|
torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1)) |
|
|
|
/ max_rgb_distance |
|
|
|
) |
|
|
|
# connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space |
|
|
|
# rgb_norm = rgb_norm % 1 # i think this is why yellow and blue end up adjacent. |
|
|
|
# yes it connects black and white, but also complimentary colors to primary |
|
|
|
# print(rgb_norm) |
|
|
|
|
|
|
|
# Calculate 1D Space Norm (modulo 1 to account for circularity) |
|
|
|
transformed_norm_a = torch.triu( |
|
|
|
torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1) |
|
|
|
) |
|
|
|
transformed_norm_b = torch.triu( |
|
|
|
torch.norm((1 + outputs[:, None] - target_outputs[None, :]) % 1, dim=-1) |
|
|
|
) |
|
|
|
transformed_norm = torch.minimum(transformed_norm_a, transformed_norm_b) |
|
|
|
|
|
|
|
diff = torch.pow(rgb_norm - transformed_norm, 2) |
|
|
|
|
|
|
|
return torch.mean(diff) |
|
|
|
|
|
|
|
|
|
|
|
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): |
|
|
|
# Distance Preservation Component (or scaled euclidean if given targets) |
|
|
|
# Encourages the model to keep relative distances from the RGB space in the transformed space |
|
|
|