Browse Source

allow for mix of supervised and not with alpha

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
b5d9e725b3
  1. 6
      model.py
  2. 7
      newsearch.py

6
model.py

@ -49,7 +49,7 @@ class ColorTransformerModel(L.LightningModule):
inputs, inputs,
outputs, outputs,
) )
# alpha = self.hparams.alpha # TODO: decide what to do with this... alpha = self.hparams.alpha
# loss = p_loss # loss = p_loss
# distance = torch.minimum( # distance = torch.minimum(
@ -58,14 +58,14 @@ class ColorTransformerModel(L.LightningModule):
distance = torch.norm(outputs - labels).mean() distance = torch.norm(outputs - labels).mean()
# Backprop with this: # Backprop with this:
loss = p_loss loss = (1 - alpha) * p_loss + alpha * distance
# p_loss is unsupervised # p_loss is unsupervised
# distance is supervised. # distance is supervised.
self.log("hp_metric", loss) self.log("hp_metric", loss)
# Log all losses individually # Log all losses individually
self.log("train_mse", distance)
self.log("train_pres", p_loss) self.log("train_pres", p_loss)
self.log("train_mse", distance)
return loss return loss
def validation_step(self, batch): def validation_step(self, batch):

7
newsearch.py

@ -24,10 +24,13 @@ NUM_JOBS = 100
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4)) # alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5)) # learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-3] learning_rate_values = [1e-3]
alpha_values = [0] # learning_rate_values = [5e-4]
alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised.
widths = [2**k for k in range(4, 13)] widths = [2**k for k in range(4, 13)]
depths = [1, 2, 4, 8, 16] depths = [1, 2, 4, 8, 16]
# learning_rate_values = [5e-4] # widths, depths = [128, 256], [4, 8]
batch_size_values = [256] batch_size_values = [256]
max_epochs_values = [10] max_epochs_values = [10]
seeds = list(range(21, 1992)) seeds = list(range(21, 1992))

Loading…
Cancel
Save