Browse Source

refactor loss functions

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
719597ce90
  1. 25
      dataloader.py
  2. 120
      losses.py
  3. 5
      main.py
  4. 2
      makefile
  5. 8
      model.py
  6. 24
      utils.py

25
dataloader.py

@ -2,6 +2,8 @@ import matplotlib.colors as mcolors
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import preprocess_data
def extract_colors():
# Extracting the list of xkcd colors as RGB triples
@ -32,7 +34,7 @@ def create_gray_supplement(N: int = 50):
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
def create_named_dataloader(N: int = 50, **kwargs):
def create_named_dataloader(N: int = 0, **kwargs):
rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
# Creating a dataset with RGB values and their corresponding color names
@ -40,29 +42,12 @@ def create_named_dataloader(N: int = 50, **kwargs):
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
for i in range(len(rgb_tensor))
]
dataset_with_names += create_gray_supplement(N)
if N > 0:
dataset_with_names += create_gray_supplement(N)
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
return train_dataloader_with_names
def preprocess_data(data):
# Assuming 'data' is a tensor of shape [n_samples, 3]
# Compute argmin and argmax for each row
argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
# Normalize or scale argmin and argmax if necessary
# For example, here I am just dividing by the number of features
argmin_values /= data.shape[1]
argmax_values /= data.shape[1]
# Concatenate the argmin and argmax values to the original data
new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
return new_data
if __name__ == "__main__":
batch_size = 4
train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)

120
losses.py

@ -1,59 +1,83 @@
import torch
# def weighted_loss(inputs, outputs, alpha):
# # Calculate RGB Norm (Perceptual Difference)
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
from utils import PURE_RGB
# # Calculate 1D Space Norm
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
# def smoothness_loss(outputs):
# # Sort outputs for smoothness calculation
# sorted_outputs, _ = torch.sort(outputs, dim=0)
# first_elements = sorted_outputs[:2]
# # Weighted Loss
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
# return torch.mean(loss)
# # Concatenate the first element at the end of the sorted_outputs
# extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
# # Calculate RGB Norm
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
# # Calculate 1D Space Norm
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
# # Identify Distinct Colors (based on a threshold in RGB space)
# distinct_colors = rgb_norm > distinct_threshold
# # Penalty for Distinct Colors being too close in the transformed space
# # Here we do not take the mean yet, to avoid double averaging
# distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
# # Combined Loss
# # The mean is taken here, once, after all components are combined
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
# return torch.mean(loss)
# # Calculate smoothness in the sorted outputs
# first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
# second_derivative = torch.diff(first_derivative, n=1, dim=0)
# smoothness_loss = torch.mean(torch.abs(second_derivative))
# return smoothness_loss
def preservation_loss(inputs, outputs):
# Calculate RGB Norm
rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
# Calculate 1D Space Norm
transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
# Distance Preservation Component
# Encourages the model to keep relative distances from the RGB space in the transformed space
return torch.mean(torch.abs(rgb_norm - transformed_norm))
def smoothness_loss(outputs):
# Sort outputs for smoothness calculation
sorted_outputs, _ = torch.sort(outputs, dim=0)
first_elements = sorted_outputs[:2]
# Concatenate the first element at the end of the sorted_outputs
extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
# Calculate smoothness in the sorted outputs
first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
second_derivative = torch.diff(first_derivative, n=1, dim=0)
smoothness_loss = torch.mean(torch.abs(second_derivative))
return smoothness_loss
# Calculate RGB Norm
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
rgb_norm = (
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
/ max_rgb_distance
)
rgb_norm = (
rgb_norm % 1
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
# print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm = torch.triu(
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
)
diff = torch.abs(rgb_norm - transformed_norm)
# print(diff)
return torch.mean(diff)
def separation_loss(red, green, blue):
# Separation Component
# Encourages the model to keep R, G, B values equally separated in the transformed space
red, green, blue = red % 1, green % 1, blue % 1
red_green_distance = torch.min(
torch.abs((red - green)), torch.abs((1 + red - green))
)
red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue)))
green_blue_distance = torch.min(
torch.abs((green - blue)), torch.abs((1 + green - blue))
)
# print(red_green_distance, red_blue_distance, green_blue_distance)
# we want these distances to be equal to one another
return (
torch.abs(red_green_distance - red_blue_distance)
+ torch.abs(red_green_distance - green_blue_distance)
+ torch.abs(red_blue_distance - green_blue_distance)
)
def calculate_separation_loss(model):
# Wrapper function to calculate separation loss
outputs = model(PURE_RGB)
red, green, blue = outputs[0], outputs[1], outputs[2]
return separation_loss(red, green, blue)
if __name__ == "__main__":
# test preservation loss
# create torch vector containing pure R, G, B.
test_input = torch.tensor(
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]], dtype=torch.float32
)
test_output = torch.tensor([[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32)
print(preservation_loss(test_input[:3], test_output[:3]))
rgb = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
print(separation_loss(red=rgb[0], green=rgb[1], blue=rgb[2]))

5
main.py

@ -2,7 +2,7 @@ import argparse
import pytorch_lightning as pl
from dataloader import create_named_dataloader as init_data
from dataloader import create_named_dataloader
from model import ColorTransformerModel
@ -45,7 +45,8 @@ if __name__ == "__main__":
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = init_data(
train_dataloader = create_named_dataloader(
N=0,
batch_size=args.bs,
shuffle=True,
num_workers=args.num_workers,

2
makefile

@ -4,7 +4,7 @@ lint:
flake8 --ignore E501,W503 .
test:
python main.py --alpha 1 --lr 1e-4 --max_epochs 500
python main.py --alpha 4 --lr 2e-4 --max_epochs 200
search:
python search.py

8
model.py

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import preservation_loss, smoothness_loss
from losses import calculate_separation_loss, preservation_loss
class ColorTransformerModel(pl.LightningModule):
@ -72,14 +72,14 @@ class ColorTransformerModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs)
s_loss = smoothness_loss(outputs)
s_loss = calculate_separation_loss(model=self)
p_loss = preservation_loss(
inputs,
outputs,
)
alpha = self.hparams.alpha
loss = p_loss + alpha * s_loss
self.log("hp_metric", p_loss)
loss = (p_loss + alpha * s_loss) / (1 + alpha)
self.log("hp_metric", loss)
self.log("train_loss", loss)
return loss

24
utils.py

@ -0,0 +1,24 @@
import torch
def preprocess_data(data):
# Assuming 'data' is a tensor of shape [n_samples, 3]
# Compute argmin and argmax for each row
argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
# Normalize or scale argmin and argmax if necessary
# For example, here I am just dividing by the number of features
argmin_values /= data.shape[1] - 1
argmax_values /= data.shape[1] - 1
# Concatenate the argmin and argmax values to the original data
new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
return new_data
PURE_RGB = preprocess_data(
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
)
Loading…
Cancel
Save