diff --git a/losses.py b/losses.py index 98badd7..f237949 100644 --- a/losses.py +++ b/losses.py @@ -44,6 +44,12 @@ def enhanced_loss(inputs, outputs, alpha): # Encourages the model to keep relative distances from the RGB space in the transformed space distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm)) + # Combined Loss + loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs) + return loss + + +def smoothness_loss(outputs): # Sort outputs for smoothness calculation sorted_outputs, _ = torch.sort(outputs, dim=0) @@ -51,7 +57,4 @@ def enhanced_loss(inputs, outputs, alpha): first_derivative = torch.diff(sorted_outputs, n=1, dim=0) second_derivative = torch.diff(first_derivative, n=1, dim=0) smoothness_loss = torch.mean(torch.abs(second_derivative)) - - # Combined Loss - loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss - return loss + return smoothness_loss