|
@ -40,7 +40,7 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): |
|
|
# print(rgb_norm) |
|
|
# print(rgb_norm) |
|
|
|
|
|
|
|
|
# Calculate 1D Space Norm (modulo 1 to account for circularity) |
|
|
# Calculate 1D Space Norm (modulo 1 to account for circularity) |
|
|
transformed_norm = circle_norm(outputs, target_outputs) # * 2 |
|
|
transformed_norm = circle_norm(outputs, target_outputs) * 2 |
|
|
|
|
|
|
|
|
diff = torch.pow(rgb_norm - transformed_norm, 2) |
|
|
diff = torch.pow(rgb_norm - transformed_norm, 2) |
|
|
N = len(outputs) |
|
|
N = len(outputs) |
|
|