Browse Source

tweak metric

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
2af491c324
  1. 11
      losses.py

11
losses.py

@ -43,16 +43,17 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
transformed_norm = circle_norm(outputs, target_outputs) * 2
diff = torch.pow(rgb_norm - transformed_norm, 2)
N = len(outputs)
N = (N * (N - 1)) / 2
# N = torch.count_nonzero(rgb_norm)
# N = len(outputs)
# N = (N * (N - 1)) / 2
N = torch.count_nonzero(rgb_norm)
return torch.sum(diff) / N
def circle_norm(vector, other_vector):
# Assumes vectors are of shape (N,1)
loss_a = torch.triu(torch.abs((vector - other_vector.T)))
loss_b = torch.triu(1 - torch.abs((vector - other_vector.T)))
diff = vector - other_vector.T
loss_a = torch.triu(torch.abs(diff))
loss_b = torch.triu(torch.abs(1 - torch.abs(diff)))
loss = torch.minimum(loss_a, loss_b)
return loss

Loading…
Cancel
Save