|
|
@ -1,6 +1,6 @@ |
|
|
|
import torch |
|
|
|
|
|
|
|
from utils import PURE_RGB |
|
|
|
from utils import RGBMYC_ANCHOR |
|
|
|
|
|
|
|
# def smoothness_loss(outputs): |
|
|
|
# # Sort outputs for smoothness calculation |
|
|
@ -40,11 +40,11 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): |
|
|
|
# print(rgb_norm) |
|
|
|
|
|
|
|
# Calculate 1D Space Norm (modulo 1 to account for circularity) |
|
|
|
transformed_norm = circle_norm(outputs, target_outputs) |
|
|
|
transformed_norm = circle_norm(outputs, target_outputs) * 2 |
|
|
|
|
|
|
|
diff = torch.pow(rgb_norm - transformed_norm, 2) |
|
|
|
N = len(outputs) |
|
|
|
return torch.sum(diff) / (N * (N - 1)) / 2 |
|
|
|
N = torch.count_nonzero(rgb_norm) |
|
|
|
return torch.sum(diff) / N |
|
|
|
|
|
|
|
|
|
|
|
def circle_norm(vector, other_vector): |
|
|
@ -68,7 +68,7 @@ def separation_loss(red, green, blue): |
|
|
|
def calculate_separation_loss(model): |
|
|
|
# TODO: remove |
|
|
|
# Wrapper function to calculate separation loss |
|
|
|
outputs = model(PURE_RGB.to(model.device)) |
|
|
|
outputs = model(RGBMYC_ANCHOR.to(model.device)) |
|
|
|
red, green, blue = outputs[0], outputs[1], outputs[2] |
|
|
|
return separation_loss(red, green, blue) |
|
|
|
|
|
|
|