You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
1.0 KiB
32 lines
1.0 KiB
11 months ago
|
import torch
|
||
|
|
||
|
|
||
|
def weighted_loss(inputs, outputs, alpha):
    """Return the mean alpha-weighted sum of pairwise input and output distances.

    For every pair ``(i, j)`` this computes the Euclidean distance between
    ``inputs[i]`` and ``inputs[j]`` (the "RGB norm") and between ``outputs[i]``
    and ``outputs[j]`` (the "1D space norm"). It combines them as
    ``alpha * rgb + (1 - alpha) * transformed`` and averages over all N*N
    pairs, including the zero-distance diagonal.

    Args:
        inputs: Tensor indexed as ``(N, C)``; each row is one sample
            (presumably an RGB colour, so ``C == 3`` -- confirm with callers).
        outputs: Tensor of shape ``(N,)`` or ``(N, D)``. A 1-D tensor is
            treated as ``(N, 1)``, i.e. one scalar coordinate per sample.
        alpha: Weight of the input-distance term; the output-distance term
            is weighted by ``1 - alpha``.

    Returns:
        A 0-d tensor holding the mean of the weighted pairwise distances.

    Note:
        ``rgb_norm`` is computed from ``inputs`` only, so it does not depend
        on ``outputs``. NOTE(review): for ``alpha < 1`` the only
        output-dependent term is a non-negative distance, so minimising this
        w.r.t. ``outputs`` pulls them together -- confirm this is intended.
    """
    # Pairwise Euclidean distances between input rows: shape (N, N).
    rgb_diff = inputs[:, None, :] - inputs[None, :, :]
    rgb_norm = torch.linalg.vector_norm(rgb_diff, dim=-1)

    # Give 1-D outputs an explicit feature axis. Without it the pairwise
    # difference is already (N, N), and reducing over dim=-1 would collapse it
    # to a length-N vector that broadcasts across rgb_norm's columns instead
    # of lining up pair by pair.
    points = outputs.unsqueeze(-1) if outputs.dim() == 1 else outputs
    out_diff = points[:, None] - points[None, :]
    transformed_norm = torch.linalg.vector_norm(out_diff, dim=-1)

    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)
|
||
|
|
||
|
|
||
|
def enhanced_loss(inputs, outputs, alpha, distinct_threshold, eps=1e-6):
    """Return a pairwise-distance loss plus a penalty for collapsing distinct inputs.

    For every pair ``(i, j)`` this computes the Euclidean distance between
    input rows (``rgb_norm``) and between output rows (``transformed_norm``).
    Pairs of different samples whose input distance exceeds
    ``distinct_threshold`` contribute ``1 / (transformed_norm + eps)`` to a
    penalty, averaged over all N*N pairs. The result is the mean of
    ``alpha * rgb_norm + transformed_norm + distinct_penalty``.

    Args:
        inputs: Tensor indexed as ``(N, C)``; each row is one sample
            (presumably an RGB colour -- confirm with callers).
        outputs: Tensor of shape ``(N,)`` or ``(N, D)``. A 1-D tensor is
            treated as ``(N, 1)``, i.e. one scalar coordinate per sample.
        alpha: Weight applied to the input-distance term.
        distinct_threshold: Input distance above which two samples count
            as distinct.
        eps: Added to output distances before taking the reciprocal, so
            coincident outputs do not divide by zero. Defaults to ``1e-6``.

    Returns:
        A 0-d tensor holding the combined loss.
    """
    # Pairwise Euclidean distances between input rows: shape (N, N).
    rgb_diff = inputs[:, None, :] - inputs[None, :, :]
    rgb_norm = torch.linalg.vector_norm(rgb_diff, dim=-1)

    # Give 1-D outputs an explicit feature axis. Without it the pairwise
    # difference is already (N, N), and reducing over dim=-1 would collapse it
    # to a length-N vector instead of an (N, N) distance matrix.
    points = outputs.unsqueeze(-1) if outputs.dim() == 1 else outputs
    out_diff = points[:, None] - points[None, :]
    transformed_norm = torch.linalg.vector_norm(out_diff, dim=-1)

    # Pairs whose inputs are farther apart than the threshold. A sample is
    # never distinct from itself: its self-distance in output space is always
    # 0, so keeping the diagonal (possible when distinct_threshold < 0) would
    # only add a constant 1/eps per sample.
    distinct_colors = rgb_norm > distinct_threshold
    self_pairs = torch.eye(
        distinct_colors.shape[0], dtype=torch.bool, device=distinct_colors.device
    )
    distinct_colors = distinct_colors & ~self_pairs

    # Penalise distinct inputs that land close together in output space.
    distinct_penalty = torch.mean((1.0 / (transformed_norm + eps)) * distinct_colors)

    # NOTE(review): unlike weighted_loss, transformed_norm is not scaled by
    # (1 - alpha) here -- confirm this weighting is intended.
    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
    return torch.mean(loss)
|