From 012c7b7c68a313bdbc72076339ea568c744b129d Mon Sep 17 00:00:00 2001 From: "Michael Pilosov, PhD" Date: Sat, 27 Jan 2024 07:48:04 +0000 Subject: [PATCH] add validation step --- model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/model.py b/model.py index 67fd130..e3e2276 100644 --- a/model.py +++ b/model.py @@ -64,6 +64,18 @@ class ColorTransformerModel(L.LightningModule): self.log("s_loss", s_loss) return loss + def validation_step(self): + inputs, labels = batch # these are true HSV labels - no learning allowed. + outputs = self.forward(inputs) + distance = torch.minimum( + torch.abs(outputs - labels), torch.abs(1 + outputs - labels) + ) + mean_loss = torch.mean(distance) + max_loss = torch.max(distance) + self.log("val_mean_loss", mean_loss) + self.log("val_max_loss", max_loss) + return mean_loss + def configure_optimizers(self): optimizer = torch.optim.SGD( self.parameters(),