diff --git a/losses.py b/losses.py index 6a6ca46..edc1f94 100644 --- a/losses.py +++ b/losses.py @@ -51,9 +51,9 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): def circle_norm(vector, other_vector): # Assumes vectors are of shape (N,1) - diff = vector - other_vector.T - loss_a = torch.triu(torch.abs(diff)) - loss_b = torch.triu(torch.abs(1 - torch.abs(diff))) + diff = torch.abs(vector - other_vector.T) + loss_a = torch.triu(diff) + loss_b = torch.triu(torch.abs(1 - diff)) loss = torch.minimum(loss_a, loss_b) return loss