diff --git a/datamodule.py b/datamodule.py
index f35c74c..b651e40 100644
--- a/datamodule.py
+++ b/datamodule.py
@@ -38,10 +38,11 @@ class ColorDataModule(L.LightningDataModule):
     def get_xkcd_colors(cls):
         rgb_tensor, xkcd_color_names = extract_colors()
         rgb_tensor = preprocess_data(rgb_tensor, skip=True)
-        return [
-            (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
-            for i in range(len(rgb_tensor))
-        ]
+        # return [
+        #     (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
+        #     for i in range(len(rgb_tensor))
+        # ]
+        return [(c, cls.get_hue(c)) for c in rgb_tensor]
 
     def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
diff --git a/losses.py b/losses.py
index eb0de91..7d34ba0 100644
--- a/losses.py
+++ b/losses.py
@@ -17,6 +17,65 @@ from utils import PURE_RGB
 #     return smoothness_loss
 
 
+def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
+    """Mean squared gap between scaled RGB distances and circular output distances.
+
+    Every row of ``inputs`` is compared with every row of ``target_inputs``, and
+    every row of ``outputs`` with every row of ``target_outputs``.  RGB distances
+    are divided by sqrt(3) (the unit-cube diagonal) so that, for inputs in
+    [0, 1], they share the [0, 1] scale of the output space, which is treated
+    as a circle of circumference 1.
+
+    Args:
+        inputs: RGB values, indexed as ``inputs[:, None, :]``.
+        outputs: Model outputs for ``inputs``.
+        target_inputs: Optional reference RGB values; defaults to ``inputs``.
+        target_outputs: Model outputs for ``target_inputs``; defaults to
+            ``outputs`` and is required whenever ``target_inputs`` is given.
+
+    Returns:
+        Scalar tensor: mean of the squared differences between the two
+        distance matrices.
+
+    Raises:
+        ValueError: If ``target_inputs`` is given without ``target_outputs``.
+    """
+    # Only a batch-vs-itself comparison yields a symmetric square matrix whose
+    # lower triangle duplicates the upper one.
+    self_comparison = target_inputs is None and target_outputs is None
+    if target_inputs is None:
+        target_inputs = inputs
+    elif target_outputs is None:
+        raise ValueError("target_outputs is required when target_inputs is given")
+    if target_outputs is None:
+        target_outputs = outputs
+
+    # Pairwise RGB distances scaled by the unit-cube diagonal, sqrt(1 + 1 + 1).
+    max_rgb_distance = 3.0**0.5
+    rgb_norm = (
+        torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1)
+        / max_rgb_distance
+    )
+    # NOTE: wrapping rgb_norm with `% 1` would join black and white, but it also
+    # places complementary colours next to primaries (e.g. yellow beside blue).
+
+    # Circular distance in the output space: the shorter of the two arcs.
+    # (x % 1 and (1 + x) % 1 are the same value, so the backward arc must be
+    # computed from the negated difference.)
+    delta = outputs[:, None] - target_outputs[None, :]
+    transformed_norm = torch.minimum(
+        torch.norm(delta % 1, dim=-1),
+        torch.norm((-delta) % 1, dim=-1),
+    )
+
+    if self_comparison:
+        # Count each unordered pair once; the zeroed entries still enter the mean.
+        rgb_norm = torch.triu(rgb_norm)
+        transformed_norm = torch.triu(transformed_norm)
+
+    return torch.mean(torch.pow(rgb_norm - transformed_norm, 2))
+
+
 def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
diff --git a/model.py b/model.py
index 97909ba..ea0a343 100644
--- a/model.py
+++ b/model.py
@@ -3,7 +3,12 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import calculate_separation_loss, preservation_loss  # noqa: F401
+from losses import (  # noqa: F401
+    calculate_separation_loss,
+    preservation_loss,
+    simple_preservation_loss,
+)
+from utils import PURE_RGB
 
 
 class ColorTransformerModel(L.LightningModule):
@@ -18,6 +23,7 @@ class ColorTransformerModel(L.LightningModule):
     ):
         super().__init__()
         self.save_hyperparameters()
+
         if self.hparams.transform.lower() == "tanh":
             t = nn.Tanh
         elif self.hparams.transform.lower() == "relu":
@@ -46,9 +52,12 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        p_loss = preservation_loss(
+        rgb_tensor = PURE_RGB.to(self.device)
+        p_loss = simple_preservation_loss(
             inputs,
             outputs,
+            target_inputs=rgb_tensor,
+            target_outputs=self.forward(rgb_tensor),
         )
         alpha = self.hparams.alpha
         # loss = p_loss
diff --git a/newsearch.py b/newsearch.py
index 13f83f9..9bef58f 100644
--- a/newsearch.py
+++ b/newsearch.py
@@ -32,8 +32,8 @@ alpha_values = [0]
 # depths = [1, 2, 4, 8, 16]
 widths, depths = [512], [4]
 
-batch_size_values = [1024]
-max_epochs_values = [250]
+batch_size_values = [256]
+max_epochs_values = [100]
 seeds = list(range(21, 1992))
 optimizers = [
     # "Adagrad",
@@ -73,7 +73,7 @@ for idx, params in enumerate(search_params):
 python newmain.py fit \
 --seed_everything {s} \
 --data.batch_size {bs} \
---data.train_size 10000 \
+--data.train_size 0 \
 --data.val_size 10000 \
 --model.alpha {a} \
 --model.width {w} \