@ -52,8 +52,8 @@ class ColorTransformerModel(L.LightningModule):
p_loss = preservation_loss(
inputs,
outputs,
# target_inputs=rgb_tensor,
target_inputs=rgb_tensor,
# target_outputs=self.forward(rgb_tensor),
target_outputs=self.forward(rgb_tensor),
)
alpha = self.hparams.alpha