from typing import Optional

import torch
from utils import RGBMYC_ANCHOR

# def smoothness_loss(outputs):
#     # Sort outputs for smoothness calculation
#     sorted_outputs, _ = torch.sort(outputs, dim=0)
#     first_elements = sorted_outputs[:2]
#     # Concatenate the first element at the end of the sorted_outputs
#     extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)
#     # Calculate smoothness in the sorted outputs
#     first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
#     second_derivative = torch.diff(first_derivative, n=1, dim=0)
#     smoothness_loss = torch.mean(torch.abs(second_derivative))
#     return smoothness_loss


def preservation_loss(
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    target_inputs: Optional[torch.Tensor] = None,
    target_outputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Penalize mismatch between pairwise RGB distances and 1D circular distances.

    Encourages the model to keep the relative distances between colours in RGB
    space after they are mapped onto a 1D circle. When ``target_inputs`` /
    ``target_outputs`` are given, distances are measured from each input to
    each target instead of between the inputs themselves.

    Args:
        inputs: Colours indexed as ``inputs[:, None, :]``, i.e. shape (N, C);
            presumably C == 3 with channel values in [0, 1], which is what the
            sqrt(3) normaliser assumes — TODO confirm.
        outputs: Model outputs for ``inputs``, shape (N, 1), treated as
            positions on a circle of circumference 1.
        target_inputs: Optional reference colours, shape (M, C). Defaults to
            ``inputs``.
        target_outputs: Outputs for ``target_inputs``, shape (M, 1). Must be
            given whenever ``target_inputs`` is; defaults to ``outputs``.

    Returns:
        Scalar tensor: the sum of squared differences over the upper triangle
        of the distance matrices, divided by N * (N - 1) / 2 (clamped to at
        least 1 so fewer than two points do not produce NaN/inf).

    Raises:
        AssertionError: if ``target_inputs`` is given without ``target_outputs``.
    """
    if target_inputs is None:
        target_inputs = inputs
    else:
        assert target_outputs is not None, "target_outputs is required with target_inputs"
    # NOTE(review): passing only target_outputs compares inputs-vs-inputs RGB
    # distances against outputs-vs-target_outputs distances; confirm no caller
    # relies on that combination.
    if target_outputs is None:
        target_outputs = outputs

    # Scale Euclidean RGB distances into [0, 1] by the cube diagonal: the
    # distance from (0, 0, 0) to (1, 1, 1) is sqrt(3). torch.triu keeps the
    # upper triangle so each unordered pair (in the no-target case) counts once.
    max_rgb_distance = 3 ** 0.5
    pairwise_rgb = torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1)
    rgb_norm = torch.triu(pairwise_rgb) / max_rgb_distance
    # Author's note: wrapping with `rgb_norm = rgb_norm % 1` was tried. It
    # connects black and white, but the author believed it also put
    # complementary colours next to primaries (yellow and blue ended up
    # adjacent), so it stays disabled.

    # Circular distances are at most 0.5; doubling puts them on the same
    # [0, 1] scale as rgb_norm.
    transformed_norm = circle_norm(outputs, target_outputs) * 2

    squared_error = torch.pow(rgb_norm - transformed_norm, 2)

    num_points = len(outputs)
    # NOTE(review): with targets (N x M matrices) this is not the number of
    # entries kept by triu; left as-is to preserve the existing loss scale.
    num_pairs = num_points * (num_points - 1) / 2
    # Fewer than two points gives num_pairs == 0; dividing by it would yield
    # NaN (or inf with targets), so clamp the denominator.
    return torch.sum(squared_error) / max(num_pairs, 1)


def circle_norm(vector: torch.Tensor, other_vector: torch.Tensor) -> torch.Tensor:
    """Return upper-triangular pairwise distances on a circle of circumference 1.

    Entry (i, j) is the shorter arc between ``vector[i]`` and
    ``other_vector[j]``, a value in [0, 0.5]; entries below the diagonal are
    zeroed by ``torch.triu``.

    Args:
        vector: Shape (N, 1).
        other_vector: Shape (M, 1); transposed to (1, M) so the difference
            broadcasts to (N, M).
    """
    # Wrap the signed difference into [0, 1) before taking the shorter arc.
    # For differences in [-1, 1] this equals min(|d|, 1 - |d|); outside that
    # range it stays a valid (non-negative) circular distance.
    wrapped = torch.remainder(vector - other_vector.T, 1.0)
    distance = torch.minimum(wrapped, 1 - wrapped)
    return torch.triu(distance)


def separation_loss(red, green, blue):
    """Penalize R, G, B outputs for straying from anchors 0, 1/3 and 2/3.

    Each term is the absolute offset from its anchor; green and blue are
    divided by 2/3 (the largest offset from their anchors within [0, 1]),
    red is not divided (its largest offset from 0 within [0, 1] is 1).
    """
    # TODO: remove
    red_loss = torch.abs(0 - red)
    green_loss = torch.abs(1 / 3 - green) / (2 / 3)
    blue_loss = torch.abs(2 / 3 - blue) / (2 / 3)
    return red_loss + green_loss + blue_loss


def calculate_separation_loss(model):
    """Run ``model`` on ``RGBMYC_ANCHOR`` and score its first three outputs.

    Outputs 0, 1 and 2 are passed to ``separation_loss`` as red, green and
    blue. NOTE(review): relies on ``model`` exposing a ``.device`` attribute,
    which plain ``torch.nn.Module`` does not provide — confirm.
    """
    # TODO: remove
    outputs = model(RGBMYC_ANCHOR.to(model.device))
    red, green, blue = outputs[0], outputs[1], outputs[2]
    return separation_loss(red, green, blue)


if __name__ == "__main__":
    # test preservation loss
    # create torch vector containing pure R, G, B.
    test_input = torch.tensor(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]], dtype=torch.float32
    )
    test_output = torch.tensor([[0], [1 / 3], [2 / 3], [0], [0]], dtype=torch.float32)
    print(preservation_loss(test_input[:3], test_output[:3]))

    rgb = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
    print(separation_loss(red=rgb[0], green=rgb[1], blue=rgb[2]))