diff --git a/losses.py b/losses.py
index 7d34ba0..9ed7e1c 100644
--- a/losses.py
+++ b/losses.py
@@ -17,7 +17,7 @@ from utils import PURE_RGB
 #     return smoothness_loss
 
 
-def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
+def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # Distance Preservation Component (or scaled euclidean if given targets)
     # Encourages the model to keep relative distances from the RGB space in the transformed space
     if target_inputs is None:
@@ -40,49 +40,19 @@ def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs
     # print(rgb_norm)
 
     # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm_a = torch.triu(
-        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-    transformed_norm_b = torch.triu(
-        torch.norm((1 + outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-    transformed_norm = torch.minimum(transformed_norm_a, transformed_norm_b)
+    transformed_norm = circle_norm(outputs, target_outputs)
 
     diff = torch.pow(rgb_norm - transformed_norm, 2)
+    # Mean over the N * (N - 1) / 2 unique pairs kept by torch.triu
+    N = len(outputs)
+    return torch.sum(diff) / (N * (N - 1) / 2)
 
-    return torch.mean(diff)
-
-
-def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
-    # Distance Preservation Component (or scaled euclidean if given targets)
-    # Encourages the model to keep relative distances from the RGB space in the transformed space
-    if target_inputs is None:
-        target_inputs = inputs
-    else:
-        assert target_outputs is not None
-    if target_outputs is None:
-        target_outputs = outputs
-
-    # Calculate RGB Norm
-    max_rgb_distance = torch.sqrt(torch.tensor(2 + 1))  # scale to [0, 1]
-    # max_rgb_distance = 1
-    rgb_norm = (
-        torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
-        / max_rgb_distance
-    )
-    # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
-    rgb_norm = rgb_norm % 1
-    # print(rgb_norm)
-
-    # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm = torch.triu(
-        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
-    )
-
-    diff = torch.abs(rgb_norm - transformed_norm)
-    # print(diff)
-    return torch.mean(diff)
+def circle_norm(vector, other_vector):
+    # Pairwise distance on a circle of circumference 1 (upper triangle only).
+    # Assumes inputs of shape (N, 1) with values in [0, 1); apply % 1 first otherwise.
+    dist_a = torch.triu(torch.abs(vector - other_vector.T))
+    dist_b = torch.triu(1 - torch.abs(vector - other_vector.T))
+    return torch.minimum(dist_a, dist_b)
 
 
 def separation_loss(red, green, blue):
diff --git a/model.py b/model.py
index ea0a343..74e19c6 100644
--- a/model.py
+++ b/model.py
@@ -3,12 +3,9 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import (  # noqa: F401
-    calculate_separation_loss,
-    preservation_loss,
-    simple_preservation_loss,
-)
-from utils import PURE_RGB
+from losses import preservation_loss
+
+# from utils import PURE_RGB
 
 
 class ColorTransformerModel(L.LightningModule):
@@ -52,24 +49,22 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        rgb_tensor = PURE_RGB.to(self.device)
-        p_loss = simple_preservation_loss(
+        # rgb_tensor = PURE_RGB.to(self.device)
+        p_loss = preservation_loss(
             inputs,
             outputs,
-            target_inputs=rgb_tensor,
-            target_outputs=self.forward(rgb_tensor),
+            # target_inputs=rgb_tensor,
+            # target_outputs=self.forward(rgb_tensor),
         )
         alpha = self.hparams.alpha
-        # loss = p_loss
-        # distance = torch.minimum(
-        #     torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
-        # ).mean()
+        # N = len(outputs)
+        # distance = circle_norm(outputs, labels) / (N * (N - 1) / 2)
         distance = torch.norm(outputs - labels).mean()
 
         # Backprop with this:
         loss = (1 - alpha) * p_loss + alpha * distance
-        # p_loss is unsupervised
+        # p_loss is unsupervised (preserve relative distances, either in-batch or to-target)
         # distance is supervised.
         self.log("hp_metric", loss)
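
Note on the circle_norm change: the old two-candidate form took min(d % 1, (1 + d) % 1), but those two expressions are equal for every d (adding 1 does not change the value mod 1), so it reduced to d % 1 rather than the true circular distance. The new min(|d|, 1 - |d|) form does wrap correctly for coordinates in [0, 1). A minimal standalone check (illustrative only, not part of the patch; the example values are made up):

# check_circle_norm.py -- illustrative sanity check, assumes losses.py is importable
import torch

from losses import circle_norm

x = torch.tensor([[0.1], [0.9], [0.5]])  # shape (N, 1), values in [0, 1)
print(circle_norm(x, x))
# Expected upper triangle:
#   d[0, 1] = min(0.8, 0.2) = 0.2  (0.1 vs 0.9: shorter path wraps through 1.0)
#   d[0, 2] = min(0.4, 0.6) = 0.4  (0.1 vs 0.5)
#   d[1, 2] = min(0.4, 0.6) = 0.4  (0.9 vs 0.5)
# Diagonal and lower triangle are zeroed by torch.triu.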
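On the new normalization in preservation_loss: torch.triu zeroes the diagonal and lower triangle, so the old torch.mean(diff) averaged the pairwise errors over all N * N cells and diluted the per-pair mean by roughly a factor of two for large N. Summing and dividing by the N * (N - 1) / 2 unique pairs restores the true per-pair mean. A toy illustration (values are made up):

# check_pair_mean.py -- illustrative only, not part of the patch
import torch

N = 4
diff = torch.triu(torch.ones(N, N), diagonal=1)  # stand-in for the pairwise errors
print(diff.mean())                     # tensor(0.3750): 6 nonzero entries over 16 cells
print(diff.sum() / (N * (N - 1) / 2))  # tensor(1.): the true mean over the 6 unique pairs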