diff --git a/losses.py b/losses.py index e68e465..6a6ca46 100644 --- a/losses.py +++ b/losses.py @@ -43,16 +43,17 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): transformed_norm = circle_norm(outputs, target_outputs) * 2 diff = torch.pow(rgb_norm - transformed_norm, 2) - N = len(outputs) - N = (N * (N - 1)) / 2 - # N = torch.count_nonzero(rgb_norm) + # N = len(outputs) + # N = (N * (N - 1)) / 2 + N = torch.count_nonzero(rgb_norm) return torch.sum(diff) / N def circle_norm(vector, other_vector): # Assumes vectors are of shape (N,1) - loss_a = torch.triu(torch.abs((vector - other_vector.T))) - loss_b = torch.triu(1 - torch.abs((vector - other_vector.T))) + diff = vector - other_vector.T + loss_a = torch.triu(torch.abs(diff)) + loss_b = torch.triu(torch.abs(1 - torch.abs(diff))) loss = torch.minimum(loss_a, loss_b) return loss