|
@ -58,8 +58,8 @@ class ColorTransformerModel(L.LightningModule): |
|
|
alpha = self.hparams.alpha |
|
|
alpha = self.hparams.alpha |
|
|
|
|
|
|
|
|
# N = len(outputs) |
|
|
# N = len(outputs) |
|
|
distance = circle_norm(outputs, labels).mean() |
|
|
# distance = circle_norm(outputs, labels).mean() |
|
|
# distance = torch.norm(outputs - labels).mean() |
|
|
distance = torch.norm(outputs - labels).mean() |
|
|
|
|
|
|
|
|
# Backprop with this: |
|
|
# Backprop with this: |
|
|
loss = (1 - alpha) * p_loss + alpha * distance |
|
|
loss = (1 - alpha) * p_loss + alpha * distance |
|
|