Browse Source

split smoothness function out

new-sep-loss
mm 11 months ago
parent
commit
b4c5ddb886
  1. 11
      losses.py

11
losses.py

@ -44,6 +44,12 @@ def enhanced_loss(inputs, outputs, alpha):
# Encourages the model to keep relative distances from the RGB space in the transformed space # Encourages the model to keep relative distances from the RGB space in the transformed space
distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm)) distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm))
# Combined Loss
loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs)
return loss
def smoothness_loss(outputs):
# Sort outputs for smoothness calculation # Sort outputs for smoothness calculation
sorted_outputs, _ = torch.sort(outputs, dim=0) sorted_outputs, _ = torch.sort(outputs, dim=0)
@ -51,7 +57,4 @@ def enhanced_loss(inputs, outputs, alpha):
first_derivative = torch.diff(sorted_outputs, n=1, dim=0) first_derivative = torch.diff(sorted_outputs, n=1, dim=0)
second_derivative = torch.diff(first_derivative, n=1, dim=0) second_derivative = torch.diff(first_derivative, n=1, dim=0)
smoothness_loss = torch.mean(torch.abs(second_derivative)) smoothness_loss = torch.mean(torch.abs(second_derivative))
return smoothness_loss
# Combined Loss
loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss
return loss

Loading…
Cancel
Save