|
@ -1,16 +1,15 @@ |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# def weighted_loss(inputs, outputs, alpha): |
|
|
|
|
|
# # Calculate RGB Norm (Perceptual Difference) |
|
|
|
|
|
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1) |
|
|
|
|
|
|
|
|
def weighted_loss(inputs, outputs, alpha):
    """Blend of pairwise-distance norms in the RGB and transformed spaces.

    Args:
        inputs: points in the source (RGB) space; assumes shape (N, C),
            presumably C == 3 for RGB — TODO confirm with callers.
        outputs: transformed points; assumes a shape (e.g. (N, D)) whose
            pairwise-difference norm yields an (N, N) matrix matching
            ``rgb_norm`` — TODO confirm.
        alpha: scalar weight, presumably in [0, 1], balancing the two terms.

    Returns:
        Scalar tensor: mean over all (i, j) pairs of
        ``alpha * ||inputs_i - inputs_j|| + (1 - alpha) * ||outputs_i - outputs_j||``.
    """
    # Calculate RGB norm (perceptual difference): pairwise distance matrix (N, N).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Calculate 1D-space norm: pairwise distance matrix in the transformed space.
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # NOTE(review): rgb_norm does not depend on `outputs`, so under gradient
    # descent it only adds a constant offset; as written this objective pushes
    # transformed_norm toward zero. Confirm this is intended (vs. e.g. an
    # |rgb_norm - transformed_norm| preservation term).
    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)
|
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold): |
|
|
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold): |
|
@ -33,7 +32,7 @@ def weighted_loss(inputs, outputs, alpha): |
|
|
# return torch.mean(loss) |
|
|
# return torch.mean(loss) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preservation_loss(inputs, outputs):
    """Distance-preservation loss between the RGB and transformed spaces.

    Encourages the model to keep relative pairwise distances from the RGB
    space in the transformed space: the loss is the mean absolute mismatch
    between the two pairwise-distance matrices.

    Args:
        inputs: points in the source (RGB) space; assumes shape (N, C) —
            TODO confirm.
        outputs: transformed points; assumes a shape whose pairwise norm
            yields an (N, N) matrix matching ``rgb_norm`` — TODO confirm.

    Returns:
        Scalar tensor: ``mean(|rgb_norm - transformed_norm|)`` over all pairs.
    """
    # Calculate RGB norm: pairwise distance matrix (N, N) in the input space.
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise distance matrix in the transformed space.
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Distance-preservation component: penalize any change in relative
    # pairwise distances between the two spaces.
    return torch.mean(torch.abs(rgb_norm - transformed_norm))
|
|
def smoothness_loss(outputs):
    """Penalize curvature of the value-sorted outputs, with wrap-around.

    Sorts the outputs so that adjacent rows are neighbors in value, then
    takes the mean absolute second finite difference as a smoothness
    (curvature) penalty. The sequence is treated as circular.

    Args:
        outputs: tensor sorted/differenced along dim 0; assumes at least
            2 rows (``[:2]`` needs two elements for the wrap) — TODO confirm.

    Returns:
        Scalar tensor: mean absolute second difference of the sorted values.
    """
    # Sort outputs along the batch dimension for the smoothness calculation.
    sorted_outputs, _ = torch.sort(outputs, dim=0)

    # Append the first TWO sorted elements so that both the first and the
    # second finite differences wrap around the end of the sequence
    # (diff of length N+2 -> N+1 -> N).
    wrapped = torch.cat((sorted_outputs, sorted_outputs[:2]), dim=0)

    # Calculate smoothness of the sorted outputs via finite differences.
    first_derivative = torch.diff(wrapped, n=1, dim=0)
    second_derivative = torch.diff(first_derivative, n=1, dim=0)

    return torch.mean(torch.abs(second_derivative))
|