import torch


def _pairwise_distances(x: torch.Tensor) -> torch.Tensor:
    """Return the (N, N) matrix of Euclidean distances between the rows of *x*.

    A 1-D tensor of shape (N,) is treated as N scalar points, i.e. as shape
    (N, 1), so the result is the absolute pairwise difference ``|x_i - x_j|``.
    For tensors with two or more dimensions the norm is taken over the last
    dimension, exactly as ``torch.norm(x[:, None] - x[None, :], dim=-1)``.
    """
    if x.dim() == 1:
        # Without this, x[:, None] - x[None, :] is already (N, N) and the norm
        # over dim=-1 would collapse it to (N,) row norms instead of pairwise
        # distances, which then silently broadcast against (N, N) tensors.
        x = x.unsqueeze(-1)
    return torch.norm(x[:, None] - x[None, :], dim=-1)


def weighted_loss(inputs: torch.Tensor, outputs: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return the mean of a weighted sum of pairwise distances in two spaces.

    Args:
        inputs: Points in the original (RGB) space, one point per row;
            distances are taken over the last dimension. Presumably shape
            (N, 3) -- TODO confirm with callers.
        outputs: The same N points in the transformed space, shape (N,) or
            (N, D). A 1-D tensor is treated as N scalar values.
        alpha: Weight of the input-space distance term; the transformed-space
            term is weighted by ``1 - alpha``.

    Returns:
        Scalar tensor: the mean over all N*N ordered pairs (diagonal included)
        of ``alpha * d_in(i, j) + (1 - alpha) * d_out(i, j)``.
    """
    # NOTE(review): rgb_norm depends only on `inputs`, so with respect to
    # `outputs` it merely adds a constant; for alpha < 1 minimizing this loss
    # just pulls transformed points together. If the goal is to preserve
    # perceptual distances, something like |rgb_norm - transformed_norm| may
    # have been intended -- confirm with the author.
    rgb_norm = _pairwise_distances(inputs)
    transformed_norm = _pairwise_distances(outputs)

    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)


def enhanced_loss(
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    alpha: float,
    distinct_threshold: float,
    *,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Return ``weighted_loss``-style distances plus a separation penalty.

    Pairs whose input-space distance is strictly greater than
    ``distinct_threshold`` are treated as "distinct" and add
    ``1 / (d_out + eps)`` to the loss, which grows as their transformed-space
    distance shrinks.

    Args:
        inputs: Points in the original (RGB) space, one point per row.
        outputs: The same N points in the transformed space, shape (N,) or
            (N, D). A 1-D tensor is treated as N scalar values.
        alpha: Weight of the input-space distance term; the transformed-space
            term is weighted by ``1 - alpha``.
        distinct_threshold: Input-space distance above which a pair is
            penalized for being close in the transformed space.
        eps: Added to the transformed distance before taking the reciprocal.
            Should be > 0: diagonal distances are 0, and ``1 / 0 * False``
            yields NaN.

    Returns:
        Scalar tensor: the mean over all N*N ordered pairs of the combined
        per-pair loss.
    """
    rgb_norm = _pairwise_distances(inputs)
    transformed_norm = _pairwise_distances(outputs)

    # Boolean (N, N) mask of pairs that are far apart in input space. The
    # diagonal has rgb distance 0, so it is excluded for any threshold >= 0.
    distinct_colors = rgb_norm > distinct_threshold

    # Per-pair penalty; the mean is deliberately deferred so the whole
    # combined loss is averaged exactly once below.
    distinct_penalty = (1.0 / (transformed_norm + eps)) * distinct_colors

    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
    return torch.mean(loss)