|
@@ -17,7 +17,7 @@ from utils import PURE_RGB
     # return smoothness_loss
 
 
-def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
+def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
     if target_inputs is None:
@@ -40,49 +40,19 @@ def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs
     # print(rgb_norm)
 
     # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm_a = torch.triu(
-        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-    transformed_norm_b = torch.triu(
-        torch.norm((1 + outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-    transformed_norm = torch.minimum(transformed_norm_a, transformed_norm_b)
+    transformed_norm = circle_norm(outputs, target_outputs)
 
     diff = torch.pow(rgb_norm - transformed_norm, 2)
 
-    N = len(outputs)
-    return torch.sum(diff) / (N * (N - 1)) / 2
-
-
-def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
-    # Distance Preservation Component (or scaled euclidean if given targets)
-    # Encourages the model to keep relative distances from the RGB space in the transformed space
-    if target_inputs is None:
-        target_inputs = inputs
-    else:
-        assert target_outputs is not None
-    if target_outputs is None:
-        target_outputs = outputs
-
-    # Calculate RGB Norm
-    max_rgb_distance = torch.sqrt(torch.tensor(2 + 1))  # scale to [0, 1]
-    # max_rgb_distance = 1
-    rgb_norm = (
-        torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
-        / max_rgb_distance
-    )
-    # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
-    rgb_norm = rgb_norm % 1
-    # print(rgb_norm)
-
-    # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm = torch.triu(
-        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-
-    diff = torch.abs(rgb_norm - transformed_norm)
-    # print(diff)
-    return torch.mean(diff)
+    return torch.mean(diff)
+
+
+def circle_norm(vector, other_vector):
+    # Assumes vectors are of shape (N,1)
+    loss_a = torch.triu(torch.abs((vector - other_vector.T)))
+    loss_b = torch.triu(1 - torch.abs((vector - other_vector.T)))
+    loss = torch.minimum(loss_a, loss_b)
+    return loss
 
 
 def separation_loss(red, green, blue):
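# --- Illustrative sketch (not part of the patch above) ---
# A quick check of the circular-distance idea behind circle_norm, assuming the
# post-patch definition shown in the diff: hue-like values live on a unit
# circle, so the distance between 0.90 and 0.10 should wrap around to 0.20
# rather than 0.80. The sample values below are made up for illustration.
import torch


def circle_norm(vector, other_vector):
    # Pairwise wrap-around distance for values in [0, 1); assumes shape (N, 1).
    loss_a = torch.triu(torch.abs((vector - other_vector.T)))
    loss_b = torch.triu(1 - torch.abs((vector - other_vector.T)))
    return torch.minimum(loss_a, loss_b)


outputs = torch.tensor([[0.10], [0.90], [0.50]])
print(circle_norm(outputs, outputs))
# Upper triangle holds the pairwise circular distances:
# d(0.10, 0.90) = 0.20, d(0.10, 0.50) = 0.40, d(0.90, 0.50) = 0.40.
# The patched preservation_loss then penalizes the squared gap between these
# distances and the scaled pairwise RGB distances via torch.mean.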
|
|