diff --git a/model.py b/model.py index 67fd130..e3e2276 100644 --- a/model.py +++ b/model.py @@ -64,6 +64,18 @@ class ColorTransformerModel(L.LightningModule): self.log("s_loss", s_loss) return loss + def validation_step(self): + inputs, labels = batch # these are true HSV labels - no learning allowed. + outputs = self.forward(inputs) + distance = torch.minimum( + torch.abs(outputs - labels), torch.abs(1 + outputs - labels) + ) + mean_loss = torch.mean(distance) + max_loss = torch.max(distance) + self.log("val_mean_loss", mean_loss) + self.log("val_max_loss", max_loss) + return mean_loss + def configure_optimizers(self): optimizer = torch.optim.SGD( self.parameters(),