import torch


# Disabled alternative loss formulations, kept commented out.
#
# def weighted_loss(inputs, outputs, alpha):
#     # Calculate RGB Norm (Perceptual Difference)
#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
#     # Calculate 1D Space Norm
#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
#     # Weighted Loss
#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
#     return torch.mean(loss)
#
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
#     # Calculate RGB Norm
#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
#     # Calculate 1D Space Norm
#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
#     # Identify Distinct Colors (based on a threshold in RGB space)
#     distinct_colors = rgb_norm > distinct_threshold
#     # Penalty for Distinct Colors being too close in the transformed space
#     # Here we do not take the mean yet, to avoid double averaging
#     distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
#     # Combined Loss
#     # The mean is taken here, once, after all components are combined
#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
#     return torch.mean(loss)


def preservation_loss(inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute mismatch between pairwise distances.

    Computes the Euclidean distance between every pair of rows of
    ``inputs`` and between every pair of rows of ``outputs``, then averages
    the absolute difference of the two distance matrices. The loss is zero
    when the transformed space reproduces the input-space distances exactly.

    Args:
        inputs: 2-D tensor of shape ``(N, C)``, one point per row (the
            original comments describe the rows as RGB colours).
        outputs: Transformed points, of shape ``(N, D)`` or ``(N,)``.
            A 1-D tensor is treated as ``N`` scalar coordinates.

    Returns:
        A scalar tensor holding the mean over all ``N * N`` ordered pairs
        (the zero diagonal is included in the average).
    """
    # Without this promotion a 1-D ``outputs`` yields an (N, N) difference
    # matrix whose last axis is then reduced by the norm, leaving an (N,)
    # vector that silently broadcasts against the (N, N) input distances.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)

    # Broadcasting builds all pairwise differences; the norm over the last
    # axis collapses each difference vector to a distance -> (N, N).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Distance preservation: penalise any deviation between the distance in
    # the input space and the distance in the transformed space.
    return torch.mean(torch.abs(rgb_norm - transformed_norm))


def smoothness_loss(outputs: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute second difference of the sorted outputs.

    The outputs are sorted along dimension 0 and the first two sorted
    entries are appended to the end, so the differences also include the
    wrap-around from the largest values back to the smallest ones.

    Args:
        outputs: Tensor whose dimension 0 indexes the samples.

    Returns:
        A scalar tensor: the mean of ``|second difference|`` taken along
        dimension 0 of the extended, sorted sequence.
    """
    sorted_outputs, _ = torch.sort(outputs, dim=0)

    # Two leading entries are needed so that the second difference, which
    # shortens the sequence by two, still yields one value per sorted entry.
    wrap_around = sorted_outputs[:2]
    extended_sorted_outputs = torch.cat((sorted_outputs, wrap_around), dim=0)

    first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
    second_derivative = torch.diff(first_derivative, n=1, dim=0)

    # Returned directly: the original bound the result to a local that
    # shadowed this function's own name.
    return torch.mean(torch.abs(second_derivative))