|
|
@ -68,11 +68,12 @@ class ColorTransformerModel(L.LightningModule): |
|
|
|
loss = (1 - alpha) * p_loss + alpha * distance |
|
|
|
# p_loss is unsupervised (preserve relative distances - either in-batch or to-target) |
|
|
|
# distance is supervised. |
|
|
|
self.log("hp_metric", loss) |
|
|
|
self.log("hp_metric", distance) |
|
|
|
|
|
|
|
# Log all losses individually |
|
|
|
self.log("train_pres", p_loss) |
|
|
|
self.log("train_mse", distance) |
|
|
|
self.log("train_loss", loss) |
|
|
|
return loss |
|
|
|
|
|
|
|
def validation_step(self, batch): |
|
|
|