
supervised questionable

branch: new-sep-loss
Michael Pilosov, PhD, 10 months ago
commit 865e7f5104
  1. .gitignore (1 changed line)
  2. losses.py (10 changed lines)
  3. makefile (2 changed lines)
  4. model.py (7 changed lines)
  5. newsearch.py (2 changed lines)
  6. utils.py (2 changed lines)

.gitignore (1 changed line)

@@ -6,3 +6,4 @@ out/
 *.tar.gz
 .pat
 out*
+.lr*

losses.py (10 changed lines)

@@ -1,6 +1,6 @@
 import torch
-from utils import PURE_RGB
+from utils import RGBMYC_ANCHOR


 # def smoothness_loss(outputs):
 #     # Sort outputs for smoothness calculation
@@ -40,11 +40,11 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # print(rgb_norm)
     # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm = circle_norm(outputs, target_outputs)
+    transformed_norm = circle_norm(outputs, target_outputs) * 2
     diff = torch.pow(rgb_norm - transformed_norm, 2)
-    N = len(outputs)
-    return torch.sum(diff) / (N * (N - 1)) / 2
+    N = torch.count_nonzero(rgb_norm)
+    return torch.sum(diff) / N


 def circle_norm(vector, other_vector):
@@ -68,7 +68,7 @@ def separation_loss(red, green, blue):
 def calculate_separation_loss(model):
     # TODO: remove
     # Wrapper function to calculate separation loss
-    outputs = model(PURE_RGB.to(model.device))
+    outputs = model(RGBMYC_ANCHOR.to(model.device))
     red, green, blue = outputs[0], outputs[1], outputs[2]
     return separation_loss(red, green, blue)
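For orientation, here is a minimal sketch of how preservation_loss reads after this hunk. The rgb_norm computation and the circle_norm body are assumptions filled in from the surrounding comments (pairwise RGB distances compared against pairwise distances on the unit circle); only the three changed lines come from the diff itself.

import torch


def circle_norm(vector, other_vector):
    # Assumed: pairwise distance on the unit circle between 1D embeddings in [0, 1).
    d = torch.abs(vector - other_vector.T)
    return torch.minimum(d, 1 - d)


def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
    target_inputs = inputs if target_inputs is None else target_inputs
    target_outputs = outputs if target_outputs is None else target_outputs
    # Assumed: pairwise Euclidean distances in RGB space.
    rgb_norm = torch.cdist(inputs, target_inputs)
    # Changed in this commit: the circular 1D distance is doubled, presumably
    # to put it on a scale closer to rgb_norm before the two are compared.
    transformed_norm = circle_norm(outputs, target_outputs) * 2
    diff = torch.pow(rgb_norm - transformed_norm, 2)
    # Changed in this commit: normalize by the number of nonzero RGB distances
    # instead of by N * (N - 1) / 2.
    N = torch.count_nonzero(rgb_norm)
    return torch.sum(diff) / N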

makefile (2 changed lines)

@@ -36,7 +36,7 @@ help:
 	# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
 	python newmain.py fit --help

-search:
+search: lint
 	python newsearch.py

 hsv:

model.py (7 changed lines)

@@ -4,8 +4,7 @@ import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from losses import preservation_loss
-# from utils import PURE_RGB
+from utils import RGBMYC_ANCHOR


 class ColorTransformerModel(L.LightningModule):
@@ -49,7 +48,7 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # rgb_tensor = PURE_RGB.to(self.device)
+        rgb_tensor = RGBMYC_ANCHOR.to(self.device)  # noqa: F841
         p_loss = preservation_loss(
             inputs,
             outputs,
@@ -59,7 +58,7 @@ class ColorTransformerModel(L.LightningModule):
         alpha = self.hparams.alpha
         # N = len(outputs)
-        # distance = circle_norm(outputs, labels) / (N*(N-1)/2)
+        # distance = circle_norm(outputs, labels).mean()
         distance = torch.norm(outputs - labels).mean()
         # Backprop with this:
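The hunks above stop just short of the backprop expression, so how p_loss and distance are actually combined is not shown. Below is a standalone sketch of the presumed convex combination, keyed to the alpha convention documented in newsearch.py (alpha = 0 unsupervised, alpha = 1 supervised); blended_loss is a hypothetical helper name, not one from the repo.

import torch


def blended_loss(p_loss: torch.Tensor, distance: torch.Tensor, alpha: float) -> torch.Tensor:
    # Assumed blend: alpha = 0 -> purely unsupervised (preservation term only);
    # alpha = 1 -> purely supervised (distance-to-labels term only).
    return alpha * distance + (1 - alpha) * p_loss


# With the alpha_values = [1] setting introduced below, the supervised
# term becomes the entire training loss.
loss = blended_loss(p_loss=torch.tensor(0.30), distance=torch.tensor(0.12), alpha=1.0)
print(loss)  # tensor(0.1200)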

newsearch.py (2 changed lines)

@@ -27,7 +27,7 @@ learning_rate_values = [1e-3]
 # learning_rate_values = [5e-4]
 # alpha_values = [0, .25, 0.5, 0.75, 1]  # alpha = 0 is unsupervised. alpha = 1 is supervised.
-alpha_values = [0]
+alpha_values = [1]
 # widths = [2**k for k in range(4, 13)]
 # depths = [1, 2, 4, 8, 16]
 widths, depths = [512], [4]

utils.py (2 changed lines)

@@ -34,7 +34,7 @@ def extract_colors():
     return rgb_tensor, xkcd_color_names


-PURE_RGB = preprocess_data(
+RGBMYC_ANCHOR = preprocess_data(
     torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
 )
 PURE_HSV = torch.tensor(
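A quick check of why the new name fits: the concatenation passed to preprocess_data yields the three primaries followed by magenta, yellow, and cyan, i.e. R, G, B, M, Y, C. The snippet below evaluates just that expression; whatever preprocess_data does afterwards is outside this hunk and assumed to leave these values intact.

import torch

# Rows 0-2: the RGB primaries. Rows 3-5: eye(3) plus a column permutation
# of itself, which fills in the secondary colors.
anchor = torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
print(anchor)
# tensor([[1., 0., 0.],   <- red
#         [0., 1., 0.],   <- green
#         [0., 0., 1.],   <- blue
#         [1., 0., 1.],   <- magenta
#         [1., 1., 0.],   <- yellow
#         [0., 1., 1.]])  <- cyan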
