|
|
@ -43,16 +43,17 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): |
|
|
|
transformed_norm = circle_norm(outputs, target_outputs) * 2 |
|
|
|
|
|
|
|
diff = torch.pow(rgb_norm - transformed_norm, 2) |
|
|
|
N = len(outputs) |
|
|
|
N = (N * (N - 1)) / 2 |
|
|
|
# N = torch.count_nonzero(rgb_norm) |
|
|
|
# N = len(outputs) |
|
|
|
# N = (N * (N - 1)) / 2 |
|
|
|
N = torch.count_nonzero(rgb_norm) |
|
|
|
return torch.sum(diff) / N |
|
|
|
|
|
|
|
|
|
|
|
def circle_norm(vector, other_vector): |
|
|
|
# Assumes vectors are of shape (N,1) |
|
|
|
loss_a = torch.triu(torch.abs((vector - other_vector.T))) |
|
|
|
loss_b = torch.triu(1 - torch.abs((vector - other_vector.T))) |
|
|
|
diff = vector - other_vector.T |
|
|
|
loss_a = torch.triu(torch.abs(diff)) |
|
|
|
loss_b = torch.triu(torch.abs(1 - torch.abs(diff))) |
|
|
|
loss = torch.minimum(loss_a, loss_b) |
|
|
|
return loss |
|
|
|
|
|
|
|