@ -58,8 +58,8 @@ class ColorTransformerModel(L.LightningModule):
alpha = self.hparams.alpha
# N = len(outputs)
distance = circle_norm(outputs, labels).mean()
# distance = torch.norm(outputs - labels).mean()
# distance = circle_norm(outputs, labels).mean()
distance = torch.norm(outputs - labels).mean()
# Backprop with this:
loss = (1 - alpha) * p_loss + alpha * distance