diff --git a/losses.py b/losses.py index 4d6943a..b5fa227 100644 --- a/losses.py +++ b/losses.py @@ -40,10 +40,12 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): # print(rgb_norm) # Calculate 1D Space Norm (modulo 1 to account for circularity) - transformed_norm = circle_norm(outputs, target_outputs) * 2 + transformed_norm = circle_norm(outputs, target_outputs) # * 2 diff = torch.pow(rgb_norm - transformed_norm, 2) - N = torch.count_nonzero(rgb_norm) + N = len(outputs) + N = (N * (N - 1)) / 2 + # N = torch.count_nonzero(rgb_norm) return torch.sum(diff) / N diff --git a/model.py b/model.py index 94c5543..9b68458 100644 --- a/model.py +++ b/model.py @@ -31,7 +31,8 @@ class ColorTransformerModel(L.LightningModule): d = self.hparams.depth bias = self.hparams.bias if self.hparams.loop: - midlayers = [nn.Linear(w, w, bias=bias), t()] * d + midlayers = [] + midlayers += [nn.Linear(w, w, bias=bias), t()] * d else: midlayers = sum( [ @@ -64,8 +65,8 @@ class ColorTransformerModel(L.LightningModule): p_loss = preservation_loss( inputs, outputs, - target_inputs=rgb_tensor, - target_outputs=self.forward(rgb_tensor), + # target_inputs=rgb_tensor, + # target_outputs=self.forward(rgb_tensor), ) alpha = self.hparams.alpha diff --git a/newsearch.py b/newsearch.py index 99d5d9d..c72cb5b 100644 --- a/newsearch.py +++ b/newsearch.py @@ -27,19 +27,19 @@ learning_rate_values = [1e-3] # learning_rate_values = [5e-4] # alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised. -alpha_values = [0.9] +alpha_values = [0] # widths = [2**k for k in range(4, 13)] # depths = [1, 2, 4, 8, 16] -widths, depths = [512], [8] +widths, depths = [512], [4] batch_size_values = [256] max_epochs_values = [100] seeds = list(range(21, 1992)) optimizers = [ # "Adagrad", - # "Adam", + "Adam", # "SGD", - "AdamW", + # "AdamW", # "LBFGS", # "RAdam", # "RMSprop", @@ -80,7 +80,7 @@ python newmain.py fit \ --model.depth {d} \ --model.bias true \ --model.loop true \ ---model.transform relu \ +--model.transform tanh \ --trainer.min_epochs 10 \ --trainer.max_epochs {me} \ --trainer.log_every_n_steps 3 \