From b4c5ddb8866c32eab9983bf8d90187d50eca467a Mon Sep 17 00:00:00 2001 From: mm Date: Sat, 30 Dec 2023 07:19:46 +0000 Subject: [PATCH] split smoothness function out --- losses.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/losses.py b/losses.py index 98badd7..f237949 100644 --- a/losses.py +++ b/losses.py @@ -44,6 +44,12 @@ def enhanced_loss(inputs, outputs, alpha): # Encourages the model to keep relative distances from the RGB space in the transformed space distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm)) + # Combined Loss + loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs) + return loss + + +def smoothness_loss(outputs): # Sort outputs for smoothness calculation sorted_outputs, _ = torch.sort(outputs, dim=0) @@ -51,7 +57,4 @@ def enhanced_loss(inputs, outputs, alpha): first_derivative = torch.diff(sorted_outputs, n=1, dim=0) second_derivative = torch.diff(first_derivative, n=1, dim=0) smoothness_loss = torch.mean(torch.abs(second_derivative)) - - # Combined Loss - loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss - return loss + return smoothness_loss