From 865e7f5104ed8fb7cf2ab3e6045df0f630c3711d Mon Sep 17 00:00:00 2001 From: "Michael Pilosov, PhD" Date: Sun, 28 Jan 2024 01:25:10 +0000 Subject: [PATCH] supervised questionable --- .gitignore | 1 + losses.py | 10 +++++----- makefile | 2 +- model.py | 7 +++---- newsearch.py | 2 +- utils.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index e5742b0..28fc640 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ out/ *.tar.gz .pat out* +.lr* diff --git a/losses.py b/losses.py index 9ed7e1c..4d6943a 100644 --- a/losses.py +++ b/losses.py @@ -1,6 +1,6 @@ import torch -from utils import PURE_RGB +from utils import RGBMYC_ANCHOR # def smoothness_loss(outputs): # # Sort outputs for smoothness calculation @@ -40,11 +40,11 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): # print(rgb_norm) # Calculate 1D Space Norm (modulo 1 to account for circularity) - transformed_norm = circle_norm(outputs, target_outputs) + transformed_norm = circle_norm(outputs, target_outputs) * 2 diff = torch.pow(rgb_norm - transformed_norm, 2) - N = len(outputs) - return torch.sum(diff) / (N * (N - 1)) / 2 + N = torch.count_nonzero(rgb_norm) + return torch.sum(diff) / N def circle_norm(vector, other_vector): @@ -68,7 +68,7 @@ def separation_loss(red, green, blue): def calculate_separation_loss(model): # TODO: remove # Wrapper function to calculate separation loss - outputs = model(PURE_RGB.to(model.device)) + outputs = model(RGBMYC_ANCHOR.to(model.device)) red, green, blue = outputs[0], outputs[1], outputs[2] return separation_loss(red, green, blue) diff --git a/makefile b/makefile index 9ff97b3..780943e 100644 --- a/makefile +++ b/makefile @@ -36,7 +36,7 @@ help: # python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau python newmain.py fit --help -search: +search: lint python newsearch.py hsv: diff --git a/model.py b/model.py index 74e19c6..5b1eb4c 100644 --- a/model.py +++ b/model.py @@ -4,8 +4,7 @@ import torch.nn as nn from torch.optim.lr_scheduler import ReduceLROnPlateau from losses import preservation_loss - -# from utils import PURE_RGB +from utils import RGBMYC_ANCHOR class ColorTransformerModel(L.LightningModule): @@ -49,7 +48,7 @@ class ColorTransformerModel(L.LightningModule): def training_step(self, batch, batch_idx): inputs, labels = batch # x are the RGB inputs, labels are the strings outputs = self.forward(inputs) - # rgb_tensor = PURE_RGB.to(self.device) + rgb_tensor = RGBMYC_ANCHOR.to(self.device) # noqa: F841 p_loss = preservation_loss( inputs, outputs, @@ -59,7 +58,7 @@ class ColorTransformerModel(L.LightningModule): alpha = self.hparams.alpha # N = len(outputs) - # distance = circle_norm(outputs, labels) / (N*(N-1)/2) + # distance = circle_norm(outputs, labels).mean() distance = torch.norm(outputs - labels).mean() # Backprop with this: diff --git a/newsearch.py b/newsearch.py index 9bef58f..978f59e 100644 --- a/newsearch.py +++ b/newsearch.py @@ -27,7 +27,7 @@ learning_rate_values = [1e-3] # learning_rate_values = [5e-4] # alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised. -alpha_values = [0] +alpha_values = [1] # widths = [2**k for k in range(4, 13)] # depths = [1, 2, 4, 8, 16] widths, depths = [512], [4] diff --git a/utils.py b/utils.py index e0f630c..1847b01 100644 --- a/utils.py +++ b/utils.py @@ -34,7 +34,7 @@ def extract_colors(): return rgb_tensor, xkcd_color_names -PURE_RGB = preprocess_data( +RGBMYC_ANCHOR = preprocess_data( torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0) ) PURE_HSV = torch.tensor(