syntax, no diff

This commit is contained in:
Michael Pilosov, PhD 2024-01-28 09:54:38 +00:00
parent 2af491c324
commit 0ae8414481

View File

@ -51,9 +51,9 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
def circle_norm(vector, other_vector):
# Assumes vectors are of shape (N,1)
diff = vector - other_vector.T
loss_a = torch.triu(torch.abs(diff))
loss_b = torch.triu(torch.abs(1 - torch.abs(diff)))
diff = torch.abs(vector - other_vector.T)
loss_a = torch.triu(diff)
loss_b = torch.triu(torch.abs(1 - diff))
loss = torch.minimum(loss_a, loss_b)
return loss