diff --git a/dataloader.py b/dataloader.py
index 0f12bc2..9551ea1 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -2,6 +2,8 @@ import matplotlib.colors as mcolors
 import torch
 from torch.utils.data import DataLoader, TensorDataset
 
+from utils import preprocess_data
+
 
 def extract_colors():
     # Extracting the list of xkcd colors as RGB triples
@@ -32,7 +34,7 @@ def create_gray_supplement(N: int = 50):
     return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
 
 
-def create_named_dataloader(N: int = 50, **kwargs):
+def create_named_dataloader(N: int = 0, **kwargs):
     rgb_tensor, xkcd_color_names = extract_colors()
     rgb_tensor = preprocess_data(rgb_tensor)
     # Creating a dataset with RGB values and their corresponding color names
@@ -40,29 +42,12 @@ def create_named_dataloader(N: int = 50, **kwargs):
         (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
         for i in range(len(rgb_tensor))
     ]
-    dataset_with_names += create_gray_supplement(N)
+    if N > 0:
+        dataset_with_names += create_gray_supplement(N)
     train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
     return train_dataloader_with_names
 
 
-def preprocess_data(data):
-    # Assuming 'data' is a tensor of shape [n_samples, 3]
-
-    # Compute argmin and argmax for each row
-    argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
-    argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
-
-    # Normalize or scale argmin and argmax if necessary
-    # For example, here I am just dividing by the number of features
-    argmin_values /= data.shape[1]
-    argmax_values /= data.shape[1]
-
-    # Concatenate the argmin and argmax values to the original data
-    new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
-
-    return new_data
-
-
 if __name__ == "__main__":
     batch_size = 4
     train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
diff --git a/losses.py b/losses.py
index 7d15f07..a72b7f5 100644
--- a/losses.py
+++ b/losses.py
@@ -1,59 +1,83 @@
 import torch
 
-# def weighted_loss(inputs, outputs, alpha):
-#     # Calculate RGB Norm (Perceptual Difference)
-#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
+from utils import PURE_RGB
 
-#     # Calculate 1D Space Norm
-#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
+# def smoothness_loss(outputs):
+#     # Sort outputs for smoothness calculation
+#     sorted_outputs, _ = torch.sort(outputs, dim=0)
+#     first_elements = sorted_outputs[:2]
 
-#     # Weighted Loss
-#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
-#     return torch.mean(loss)
+#     # Concatenate the first element at the end of the sorted_outputs
+#     extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
 
-
-# def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
-#     # Calculate RGB Norm
-#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
-
-#     # Calculate 1D Space Norm
-#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
-
-#     # Identify Distinct Colors (based on a threshold in RGB space)
-#     distinct_colors = rgb_norm > distinct_threshold
-
-#     # Penalty for Distinct Colors being too close in the transformed space
-#     # Here we do not take the mean yet, to avoid double averaging
-#     distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
-
-#     # Combined Loss
-#     # The mean is taken here, once, after all components are combined
-#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
-#     return torch.mean(loss)
+#     # Calculate smoothness in the sorted outputs
+#     first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
+#     second_derivative = torch.diff(first_derivative, n=1, dim=0)
+#     smoothness_loss = torch.mean(torch.abs(second_derivative))
+#     return smoothness_loss
 
 
 def preservation_loss(inputs, outputs):
-    # Calculate RGB Norm
-    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
-
-    # Calculate 1D Space Norm
-    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
-
     # Distance Preservation Component
     # Encourages the model to keep relative distances from the RGB space in the transformed space
-    return torch.mean(torch.abs(rgb_norm - transformed_norm))
-
-def smoothness_loss(outputs):
-    # Sort outputs for smoothness calculation
-    sorted_outputs, _ = torch.sort(outputs, dim=0)
-    first_elements = sorted_outputs[:2]
-
-    # Concatenate the first element at the end of the sorted_outputs
-    extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
-
-    # Calculate smoothness in the sorted outputs
-    first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
-    second_derivative = torch.diff(first_derivative, n=1, dim=0)
-    smoothness_loss = torch.mean(torch.abs(second_derivative))
-    return smoothness_loss
+    # Calculate RGB Norm
+    max_rgb_distance = torch.sqrt(torch.tensor(2 + 1))  # scale to [0, 1]
+    rgb_norm = (
+        torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
+        / max_rgb_distance
+    )
+    rgb_norm = (
+        rgb_norm % 1
+    )  # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
+    # print(rgb_norm)
+
+    # Calculate 1D Space Norm (modulo 1 to account for circularity)
+    transformed_norm = torch.triu(
+        torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
+    )
+
+    diff = torch.abs(rgb_norm - transformed_norm)
+    # print(diff)
+
+    return torch.mean(diff)
+
+
+def separation_loss(red, green, blue):
+    # Separation Component
+    # Encourages the model to keep R, G, B values equally separated in the transformed space
+    red, green, blue = red % 1, green % 1, blue % 1
+    red_green_distance = torch.min(
+        torch.abs((red - green)), torch.abs((1 + red - green))
+    )
+    red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue)))
+    green_blue_distance = torch.min(
+        torch.abs((green - blue)), torch.abs((1 + green - blue))
+    )
+    # print(red_green_distance, red_blue_distance, green_blue_distance)
+    # we want these distances to be equal to one another
+    return (
+        torch.abs(red_green_distance - red_blue_distance)
+        + torch.abs(red_green_distance - green_blue_distance)
+        + torch.abs(red_blue_distance - green_blue_distance)
+    )
+
+
+def calculate_separation_loss(model):
+    # Wrapper function to calculate separation loss on the pure R, G, B anchors
+    outputs = model(PURE_RGB)
+    red, green, blue = outputs[0], outputs[1], outputs[2]
+    return separation_loss(red, green, blue)
+
+
+if __name__ == "__main__":
+    # Quick manual test of preservation_loss and separation_loss
+    # Create a test tensor with pure R, G, B plus black and white
+    test_input = torch.tensor(
+        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]], dtype=torch.float32
+    )
+    test_output = torch.tensor([[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32)
+
+    print(preservation_loss(test_input[:3], test_output[:3]))
+    rgb = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
+    print(separation_loss(red=rgb[0], green=rgb[1], blue=rgb[2]))
diff --git a/main.py b/main.py
index 1d59ddf..a40256c 100644
--- a/main.py
+++ b/main.py
@@ -2,7 +2,7 @@ import argparse
 
 import pytorch_lightning as pl
 
-from dataloader import create_named_dataloader as init_data
+from dataloader import create_named_dataloader
 from model import ColorTransformerModel
 
 
@@ -45,7 +45,8 @@ if __name__ == "__main__":
 
     # Initialize data loader with parsed arguments
     # named_data_loader also has grayscale extras. TODO: remove unnamed
-    train_dataloader = init_data(
+    train_dataloader = create_named_dataloader(
+        N=0,
         batch_size=args.bs,
         shuffle=True,
         num_workers=args.num_workers,
diff --git a/makefile b/makefile
index 3c7e957..6e25865 100644
--- a/makefile
+++ b/makefile
@@ -4,7 +4,7 @@ lint:
 	flake8 --ignore E501,W503 .
 
 test:
-	python main.py --alpha 1 --lr 1e-4 --max_epochs 500
+	python main.py --alpha 4 --lr 2e-4 --max_epochs 200
 
 search:
 	python search.py
diff --git a/model.py b/model.py
index b07bc25..9d87202 100644
--- a/model.py
+++ b/model.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import preservation_loss, smoothness_loss
+from losses import calculate_separation_loss, preservation_loss
 
 
 class ColorTransformerModel(pl.LightningModule):
@@ -72,14 +72,14 @@ class ColorTransformerModel(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        s_loss = smoothness_loss(outputs)
+        s_loss = calculate_separation_loss(model=self)
         p_loss = preservation_loss(
             inputs,
             outputs,
         )
         alpha = self.hparams.alpha
-        loss = p_loss + alpha * s_loss
-        self.log("hp_metric", p_loss)
+        loss = (p_loss + alpha * s_loss) / (1 + alpha)
+        self.log("hp_metric", loss)
         self.log("train_loss", loss)
         return loss
 
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..0c41784
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,24 @@
+import torch
+
+
+def preprocess_data(data):
+    # Assuming 'data' is a tensor of shape [n_samples, 3]
+
+    # Compute argmin and argmax for each row
+    argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
+    argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
+
+    # Normalize the index features into [0, 1]
+    # (divide by n_channels - 1, so channel indices 0, 1, 2 map to 0.0, 0.5, 1.0)
+    argmin_values /= data.shape[1] - 1
+    argmax_values /= data.shape[1] - 1
+
+    # Concatenate the argmin and argmax values to the original data
+    new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
+
+    return new_data
+
+
+PURE_RGB = preprocess_data(
+    torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
+)
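A minimal sanity check of the new preprocess_data features (a sketch, not part of the patch; it assumes the patch above has been applied so that utils.py is importable, and the expected values follow from the code, with torch.argmin/torch.argmax breaking ties at the first index):

    import torch

    from utils import PURE_RGB, preprocess_data

    # Each RGB row gains two extra features: the argmin and argmax channel
    # indices, divided by (n_channels - 1) so they land in {0.0, 0.5, 1.0}.
    rgb = torch.tensor([[1.0, 0.0, 0.0]])
    print(preprocess_data(rgb))
    # tensor([[1.0000, 0.0000, 0.0000, 0.5000, 0.0000]])
    # argmin is index 1 (first of the tied zeros) -> 0.5; argmax is index 0 -> 0.0

    # PURE_RGB is therefore a [3, 5] tensor, matching the preprocessed training inputs
    print(PURE_RGB.shape)  # torch.Size([3, 5])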