
remove modulo from 3-space

new-sep-loss
Michael Pilosov, PhD · 10 months ago
commit 953488be4c
Changed files:
  1. datamodule.py (9 lines changed)
  2. losses.py (36 lines changed)
  3. model.py (13 lines changed)
  4. newsearch.py (6 lines changed)

datamodule.py (9 lines changed)

@@ -38,10 +38,11 @@ class ColorDataModule(L.LightningDataModule):
     def get_xkcd_colors(cls):
         rgb_tensor, xkcd_color_names = extract_colors()
         rgb_tensor = preprocess_data(rgb_tensor, skip=True)
-        return [
-            (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
-            for i in range(len(rgb_tensor))
-        ]
+        # return [
+        #     (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
+        #     for i in range(len(rgb_tensor))
+        # ]
+        return [(c, cls.get_hue(c)) for c in rgb_tensor]

     def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
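The labels change from xkcd color names to hue values computed by cls.get_hue, which is not part of this diff. A minimal sketch of what such a classmethod might look like, assuming RGB components in [0, 1] and using the standard library's colorsys (the repo's version may differ, e.g. by binning hues into named classes):

import colorsys

import torch


class ColorDataModule:
    @classmethod
    def get_hue(cls, rgb: torch.Tensor) -> float:
        # rgb: a length-3 tensor with components in [0, 1].
        # colorsys returns hue as the first HSV component, also in [0, 1].
        r, g, b = rgb.tolist()
        hue, _saturation, _value = colorsys.rgb_to_hsv(r, g, b)
        return hue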

losses.py (36 lines changed)

@@ -17,6 +17,42 @@ from utils import PURE_RGB
 #     return smoothness_loss


+def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
+    # Distance Preservation Component (or scaled euclidean if given targets)
+    # Encourages the model to keep relative distances from the RGB space in the transformed space
+    if target_inputs is None:
+        target_inputs = inputs
+    else:
+        assert target_outputs is not None
+    if target_outputs is None:
+        target_outputs = outputs
+
+    # Calculate RGB Norm
+    max_rgb_distance = torch.sqrt(torch.tensor(2 + 1))  # scale to [0, 1]
+    # max_rgb_distance = 1
+    rgb_norm = (
+        torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
+        / max_rgb_distance
+    )
+    # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
+    # rgb_norm = rgb_norm % 1  # i think this is why yellow and blue end up adjacent.
+    # yes it connects black and white, but also complementary colors to primary
+    # print(rgb_norm)
+
+    # Calculate 1D Space Norm (modulo 1 to account for circularity)
+    transformed_norm_a = torch.triu(
+        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
+    )
+    transformed_norm_b = torch.triu(
+        torch.norm((1 + outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
+    )
+    transformed_norm = torch.minimum(transformed_norm_a, transformed_norm_b)
+
+    diff = torch.pow(rgb_norm - transformed_norm, 2)
+
+    return torch.mean(diff)
+
+
 def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
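This new function is what the commit title refers to: the `% 1` is removed from the RGB (3-space) norm and kept only in the 1D output norm, where wrap-around is intended. The commented-out line records why. A standalone numeric check (not from the repo) makes the old bug concrete: any two colors that differ by 1 in every channel, such as yellow and blue, sit at the maximum scaled RGB distance of exactly 1, which the modulo collapsed to 0:

import torch

max_rgb_distance = torch.sqrt(torch.tensor(3.0))

yellow = torch.tensor([1.0, 1.0, 0.0])
blue = torch.tensor([0.0, 0.0, 1.0])

# Yellow and blue differ by 1 in every channel, so their scaled
# distance is sqrt(3) / sqrt(3) = 1.0 ...
d = torch.norm(yellow - blue) / max_rgb_distance
print(d)      # tensor(1.)

# ... which the removed modulo collapsed to 0, telling the model that
# yellow and blue should land at the same point in the 1D space.
print(d % 1)  # tensor(0.)

The same collapse glued black to white and every complementary pair to distance zero, exactly as the in-code comment describes.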

model.py (13 lines changed)

@@ -3,7 +3,12 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-from losses import calculate_separation_loss, preservation_loss  # noqa: F401
+from losses import (  # noqa: F401
+    calculate_separation_loss,
+    preservation_loss,
+    simple_preservation_loss,
+)
+from utils import PURE_RGB


 class ColorTransformerModel(L.LightningModule):
@@ -18,6 +23,7 @@ class ColorTransformerModel(L.LightningModule):
     ):
         super().__init__()
         self.save_hyperparameters()
+
         if self.hparams.transform.lower() == "tanh":
             t = nn.Tanh
         elif self.hparams.transform.lower() == "relu":
@@ -46,9 +52,12 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        p_loss = preservation_loss(
+        rgb_tensor = PURE_RGB.to(self.device)
+        p_loss = simple_preservation_loss(
             inputs,
             outputs,
+            target_inputs=rgb_tensor,
+            target_outputs=self.forward(rgb_tensor),
         )
         alpha = self.hparams.alpha
         # loss = p_loss
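training_step now anchors the preservation term to the fixed PURE_RGB colors instead of comparing batch members only to each other. PURE_RGB is imported from utils and not shown in this commit; a hypothetical stand-in below illustrates the resulting pairwise-distance shape (the real tensor may contain different or additional anchor colors):

import torch

# Hypothetical stand-in for utils.PURE_RGB: the six pure primary and
# secondary colors (the repo's actual tensor may differ).
PURE_RGB = torch.tensor(
    [
        [1.0, 0.0, 0.0],  # red
        [0.0, 1.0, 0.0],  # green
        [0.0, 0.0, 1.0],  # blue
        [1.0, 1.0, 0.0],  # yellow
        [0.0, 1.0, 1.0],  # cyan
        [1.0, 0.0, 1.0],  # magenta
    ]
)

batch = torch.rand(256, 3)  # a batch of RGB inputs
# With target_inputs=PURE_RGB, the loss compares every batch color against
# the fixed anchors rather than against the rest of the batch:
pairwise = torch.norm(batch[:, None, :] - PURE_RGB[None, :, :], dim=-1)
print(pairwise.shape)  # torch.Size([256, 6])

Note that target_outputs=self.forward(rgb_tensor) re-embeds the anchors every step, so they stay fixed in RGB while their positions in the learned 1D space move with the model.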

newsearch.py (6 lines changed)

@@ -32,8 +32,8 @@ alpha_values = [0]
 # depths = [1, 2, 4, 8, 16]
 widths, depths = [512], [4]
-batch_size_values = [1024]
-max_epochs_values = [250]
+batch_size_values = [256]
+max_epochs_values = [100]
 seeds = list(range(21, 1992))

 optimizers = [
     # "Adagrad",
@@ -73,7 +73,7 @@ for idx, params in enumerate(search_params):
     python newmain.py fit \
         --seed_everything {s} \
         --data.batch_size {bs} \
-        --data.train_size 10000 \
+        --data.train_size 0 \
         --data.val_size 10000 \
         --model.alpha {a} \
         --model.width {w} \
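The search now uses smaller batches (256) and fewer epochs (100), and sets --data.train_size 0; whether 0 means "use the entire dataset" depends on the datamodule and is not shown here. The script evidently expands the value lists into search_params and renders one newmain.py command per combination; that construction is outside this hunk, but a minimal sketch of one way to build such a grid with itertools.product (all names beyond the visible lines are assumptions):

import itertools

# Value lists as in the visible portion of newsearch.py.
widths, depths = [512], [4]
batch_size_values = [256]
max_epochs_values = [100]
alpha_values = [0]
seeds = list(range(21, 1992))

# One tuple per hyperparameter combination; the real script may order,
# name, or filter these differently.
search_params = list(
    itertools.product(
        seeds, batch_size_values, alpha_values, widths, depths, max_epochs_values
    )
)

for idx, (s, bs, a, w, d, me) in enumerate(search_params[:3]):
    # Only the flags visible in the diff are reproduced here; the real
    # command template includes more options.
    cmd = (
        f"python newmain.py fit "
        f"--seed_everything {s} "
        f"--data.batch_size {bs} "
        f"--data.train_size 0 "
        f"--data.val_size 10000 "
        f"--model.alpha {a} "
        f"--model.width {w}"
    )
    print(cmd)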
