diff --git a/model.py b/model.py index 5b1eb4c..4dce517 100644 --- a/model.py +++ b/model.py @@ -52,8 +52,8 @@ class ColorTransformerModel(L.LightningModule): p_loss = preservation_loss( inputs, outputs, - # target_inputs=rgb_tensor, - # target_outputs=self.forward(rgb_tensor), + target_inputs=rgb_tensor, + target_outputs=self.forward(rgb_tensor), ) alpha = self.hparams.alpha