You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
2.2 KiB

11 months ago
import torch
11 months ago
# def weighted_loss(inputs, outputs, alpha):
# # Calculate RGB Norm (Perceptual Difference)
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
11 months ago
11 months ago
# # Calculate 1D Space Norm
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
11 months ago
11 months ago
# # Weighted Loss
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
# return torch.mean(loss)
11 months ago
11 months ago
# def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
# # Calculate RGB Norm
# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
# # Calculate 1D Space Norm
# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
# # Identify Distinct Colors (based on a threshold in RGB space)
# distinct_colors = rgb_norm > distinct_threshold
# # Penalty for Distinct Colors being too close in the transformed space
# # Here we do not take the mean yet, to avoid double averaging
# distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
# # Combined Loss
# # The mean is taken here, once, after all components are combined
# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
# return torch.mean(loss)
11 months ago
def preservation_loss(inputs, outputs):
    """Penalize mismatch between pairwise input distances and pairwise output distances.

    Builds the full pairwise Euclidean distance matrix of ``inputs`` and of
    ``outputs`` and returns the mean absolute difference between the two. This
    encourages the transformed space to keep the relative distances of the
    input space.

    Args:
        inputs: 2-D tensor of shape ``(N, C)``, one point per row (presumably
            RGB colors with C == 3, judging by the naming -- confirm against
            callers).
        outputs: tensor of shape ``(N,)`` or ``(N, D)`` holding the transformed
            points, one per input row.

    Returns:
        Scalar tensor: the mean over all N*N pairs (i, j) of
        ``|d_in(i, j) - d_out(i, j)|``.
    """
    # A flat (N,) output would make the pairwise difference (N, N); its norm over
    # dim=-1 would then collapse to (N,) and silently mis-broadcast against the
    # (N, N) input-distance matrix. Give it an explicit feature dim so it is
    # treated like an (N, 1) output.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances in the input space: (N, N).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
    # Pairwise Euclidean distances in the transformed space: (N, N).
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Distance preservation: keep the relative distances of the input space in
    # the transformed space.
    return torch.mean(torch.abs(rgb_norm - transformed_norm))
def smoothness_loss(outputs):
    """Mean absolute circular second difference of the sorted outputs.

    The outputs are sorted ascending along dim 0. The two smallest sorted values
    are then appended to the end, so the second difference also wraps around
    from the largest values back to the smallest (a circular arrangement). The
    loss is the mean absolute value of that wrapped second difference.

    Args:
        outputs: tensor whose dim 0 indexes samples, e.g. shape ``(N,)`` or
            ``(N, 1)``.

    Returns:
        Scalar tensor. With fewer than two samples there is no second
        difference; a zero that stays attached to ``outputs`` in the autograd
        graph is returned instead of NaN.
    """
    if outputs.shape[0] < 2:
        # torch.mean over the empty second difference would be NaN and poison a
        # summed training loss (e.g. on a final batch of size 1). Multiplying by
        # zero keeps the result connected to the autograd graph.
        return outputs.sum() * 0.0

    sorted_outputs, _ = torch.sort(outputs, dim=0)

    # Append the first TWO sorted values, so the second difference also covers
    # the wrap-around triples (x[n-2], x[n-1], x[0]) and (x[n-1], x[0], x[1]).
    wrapped = torch.cat((sorted_outputs, sorted_outputs[:2]), dim=0)

    first_difference = torch.diff(wrapped, n=1, dim=0)
    second_difference = torch.diff(first_difference, n=1, dim=0)
    return torch.mean(torch.abs(second_difference))