|
|
@ -44,6 +44,12 @@ def enhanced_loss(inputs, outputs, alpha): |
|
|
|
# Encourages the model to keep relative distances from the RGB space in the transformed space |
|
|
|
distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm)) |
|
|
|
|
|
|
|
# Combined Loss |
|
|
|
loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs) |
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
def smoothness_loss(outputs): |
|
|
|
# Sort outputs for smoothness calculation |
|
|
|
sorted_outputs, _ = torch.sort(outputs, dim=0) |
|
|
|
|
|
|
@ -51,7 +57,4 @@ def enhanced_loss(inputs, outputs, alpha): |
|
|
|
first_derivative = torch.diff(sorted_outputs, n=1, dim=0) |
|
|
|
second_derivative = torch.diff(first_derivative, n=1, dim=0) |
|
|
|
smoothness_loss = torch.mean(torch.abs(second_derivative)) |
|
|
|
|
|
|
|
# Combined Loss |
|
|
|
loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss |
|
|
|
return loss |
|
|
|
return smoothness_loss |
|
|
|