|
@ -64,6 +64,18 @@ class ColorTransformerModel(L.LightningModule): |
|
|
self.log("s_loss", s_loss) |
|
|
self.log("s_loss", s_loss) |
|
|
return loss |
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
def validation_step(self): |
|
|
|
|
|
inputs, labels = batch # these are true HSV labels - no learning allowed. |
|
|
|
|
|
outputs = self.forward(inputs) |
|
|
|
|
|
distance = torch.minimum( |
|
|
|
|
|
torch.abs(outputs - labels), torch.abs(1 + outputs - labels) |
|
|
|
|
|
) |
|
|
|
|
|
mean_loss = torch.mean(distance) |
|
|
|
|
|
max_loss = torch.max(distance) |
|
|
|
|
|
self.log("val_mean_loss", mean_loss) |
|
|
|
|
|
self.log("val_max_loss", max_loss) |
|
|
|
|
|
return mean_loss |
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
def configure_optimizers(self): |
|
|
optimizer = torch.optim.SGD( |
|
|
optimizer = torch.optim.SGD( |
|
|
self.parameters(), |
|
|
self.parameters(), |
|
|