diff --git a/model.py b/model.py index 477b094..dd6029d 100644 --- a/model.py +++ b/model.py @@ -58,8 +58,8 @@ class ColorTransformerModel(L.LightningModule): alpha = self.hparams.alpha # N = len(outputs) - distance = circle_norm(outputs, labels).mean() - # distance = torch.norm(outputs - labels).mean() + # distance = circle_norm(outputs, labels).mean() + distance = torch.norm(outputs - labels).mean() # Backprop with this: loss = (1 - alpha) * p_loss + alpha * distance