Browse Source

tracked down the losses bug

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
b6d9f94d8e
  1. 50
      losses.py
  2. 25
      model.py

50
losses.py

@ -17,7 +17,7 @@ from utils import PURE_RGB
# return smoothness_loss # return smoothness_loss
def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None): def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
# Distance Preservation Component (or scaled euclidean if given targets) # Distance Preservation Component (or scaled euclidean if given targets)
# Encourages the model to keep relative distances from the RGB space in the transformed space # Encourages the model to keep relative distances from the RGB space in the transformed space
if target_inputs is None: if target_inputs is None:
@ -40,49 +40,19 @@ def simple_preservation_loss(inputs, outputs, target_inputs=None, target_outputs
# print(rgb_norm) # print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity) # Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm_a = torch.triu( transformed_norm = circle_norm(outputs, target_outputs)
torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
)
transformed_norm_b = torch.triu(
torch.norm((1 + outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
)
transformed_norm = torch.minimum(transformed_norm_a, transformed_norm_b)
diff = torch.pow(rgb_norm - transformed_norm, 2) diff = torch.pow(rgb_norm - transformed_norm, 2)
N = len(outputs)
return torch.sum(diff) / (N * (N - 1)) / 2
return torch.mean(diff)
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
# Distance Preservation Component (or scaled euclidean if given targets)
# Encourages the model to keep relative distances from the RGB space in the transformed space
if target_inputs is None:
target_inputs = inputs
else:
assert target_outputs is not None
if target_outputs is None:
target_outputs = outputs
# Calculate RGB Norm
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
# max_rgb_distance = 1
rgb_norm = (
torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
/ max_rgb_distance
)
# connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
rgb_norm = rgb_norm % 1
# print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm = torch.triu(
torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
)
diff = torch.abs(rgb_norm - transformed_norm)
# print(diff)
return torch.mean(diff) def circle_norm(vector, other_vector):
# Assumes vectors are of shape (N,1)
loss_a = torch.triu(torch.abs((vector - other_vector.T)))
loss_b = torch.triu(1 - torch.abs((vector - other_vector.T)))
loss = torch.minimum(loss_a, loss_b)
return loss
def separation_loss(red, green, blue): def separation_loss(red, green, blue):

25
model.py

@ -3,12 +3,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import ( # noqa: F401 from losses import preservation_loss
calculate_separation_loss,
preservation_loss, # from utils import PURE_RGB
simple_preservation_loss,
)
from utils import PURE_RGB
class ColorTransformerModel(L.LightningModule): class ColorTransformerModel(L.LightningModule):
@ -52,24 +49,22 @@ class ColorTransformerModel(L.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs) outputs = self.forward(inputs)
rgb_tensor = PURE_RGB.to(self.device) # rgb_tensor = PURE_RGB.to(self.device)
p_loss = simple_preservation_loss( p_loss = preservation_loss(
inputs, inputs,
outputs, outputs,
target_inputs=rgb_tensor, # target_inputs=rgb_tensor,
target_outputs=self.forward(rgb_tensor), # target_outputs=self.forward(rgb_tensor),
) )
alpha = self.hparams.alpha alpha = self.hparams.alpha
# loss = p_loss
# distance = torch.minimum( # N = len(outputs)
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels) # distance = circle_norm(outputs, labels) / (N*(N-1)/2)
# ).mean()
distance = torch.norm(outputs - labels).mean() distance = torch.norm(outputs - labels).mean()
# Backprop with this: # Backprop with this:
loss = (1 - alpha) * p_loss + alpha * distance loss = (1 - alpha) * p_loss + alpha * distance
# p_loss is unsupervised # p_loss is unsupervised (preserve relative distances - either in-batch or to-target)
# distance is supervised. # distance is supervised.
self.log("hp_metric", loss) self.log("hp_metric", loss)

Loading…
Cancel
Save