import torch


def _pairwise_distances(points: torch.Tensor) -> torch.Tensor:
    """Return the (N, N) matrix of Euclidean distances between rows of *points*.

    A 1-D tensor of shape (N,) is treated as N one-dimensional points, so the
    result is the matrix of absolute differences ``|p_i - p_j|``.
    """
    if points.dim() == 1:
        # Without this, ``points[:, None] - points[None, :]`` is already
        # (N, N) and the norm over dim=-1 would collapse it to a length-N
        # vector that silently broadcasts against other (N, N) matrices.
        points = points.unsqueeze(-1)
    return torch.norm(points[:, None] - points[None, :], dim=-1)


def weighted_loss(
    inputs: torch.Tensor, outputs: torch.Tensor, alpha: float
) -> torch.Tensor:
    """Return the mean of a blend of pairwise input and output distances.

    Args:
        inputs: Points in the original space. The original comments call
            these RGB colours, presumably shape (N, 3) -- TODO confirm.
        outputs: The corresponding points in the transformed space, shape
            (N,) or (N, D).
        alpha: Weight of the input-distance term; ``1 - alpha`` weights the
            output-distance term.

    Returns:
        A scalar tensor: ``mean(alpha * d_in + (1 - alpha) * d_out)`` over all
        N*N ordered pairs, including the zero-distance diagonal.
    """
    rgb_norm = _pairwise_distances(inputs)
    transformed_norm = _pairwise_distances(outputs)
    # NOTE(review): when ``inputs`` does not require grad the first term is a
    # constant, so minimizing this loss only pulls all outputs together.
    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)


def enhanced_loss(
    inputs: torch.Tensor, outputs: torch.Tensor, alpha: float
) -> torch.Tensor:
    """Return a blend of a distance-preservation loss and a smoothness loss.

    Args:
        inputs: Points in the original space (presumably (N, 3) RGB colours
            -- TODO confirm).
        outputs: The corresponding points in the transformed space, shape
            (N,) or (N, D).
        alpha: Weight of the distance-preservation term; ``1 - alpha``
            weights :func:`smoothness_loss` of ``outputs``.

    Returns:
        A scalar tensor.
    """
    rgb_norm = _pairwise_distances(inputs)
    transformed_norm = _pairwise_distances(outputs)
    # Mean absolute mismatch between pairwise distances in the two spaces:
    # encourages the transform to keep the relative distances of the inputs.
    distance_preservation_loss = torch.mean(
        torch.abs(rgb_norm - transformed_norm)
    )
    return alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(
        outputs
    )


def smoothness_loss(outputs: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute second difference of the sorted outputs.

    ``outputs`` is sorted along dim 0 (each column independently when it has
    more than one dimension), and the discrete second derivative is taken
    along that axis.

    With fewer than three points there is no second difference. In that case
    a zero scalar (same dtype/device as ``outputs``) is returned, instead of
    the NaN that ``mean`` of an empty tensor would produce.
    """
    if outputs.shape[0] < 3:
        return outputs.new_zeros(())
    sorted_outputs, _ = torch.sort(outputs, dim=0)
    second_derivative = torch.diff(sorted_outputs, n=2, dim=0)
    return torch.mean(torch.abs(second_derivative))