diff --git a/model.py b/model.py index 775fdda..339d89b 100644 --- a/model.py +++ b/model.py @@ -68,11 +68,12 @@ class ColorTransformerModel(L.LightningModule): loss = (1 - alpha) * p_loss + alpha * distance # p_loss is unsupervised (preserve relative distances - either in-batch or to-target) # distance is supervised. - self.log("hp_metric", loss) + self.log("hp_metric", distance) # Log all losses individually self.log("train_pres", p_loss) self.log("train_mse", distance) + self.log("train_loss", loss) return loss def validation_step(self, batch): diff --git a/newsearch.py b/newsearch.py index 70b362d..6b30c69 100644 --- a/newsearch.py +++ b/newsearch.py @@ -27,7 +27,7 @@ learning_rate_values = [1e-3] # learning_rate_values = [5e-4] # alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised. -alpha_values = [1.0] +alpha_values = [0.99] # widths = [2**k for k in range(4, 13)] # depths = [1, 2, 4, 8, 16] widths, depths = [512], [1, 2, 4]