|
|
@ -46,21 +46,10 @@ def preservation_loss(inputs, outputs): |
|
|
|
def separation_loss(red, green, blue): |
|
|
|
# Separation Component |
|
|
|
# Encourages the model to keep R, G, B values equally separated in the transformed space |
|
|
|
red, green, blue = red % 1, green % 1, blue % 1 |
|
|
|
red_green_distance = torch.min( |
|
|
|
torch.abs((red - green)), torch.abs((1 + red - green)) |
|
|
|
) |
|
|
|
red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue))) |
|
|
|
green_blue_distance = torch.min( |
|
|
|
torch.abs((green - blue)), torch.abs((1 + green - blue)) |
|
|
|
) |
|
|
|
# print(red_green_distance, red_blue_distance, green_blue_distance) |
|
|
|
# we want these distances to be equal to one another |
|
|
|
return ( |
|
|
|
torch.abs(red_green_distance - red_blue_distance) |
|
|
|
+ torch.abs(red_green_distance - green_blue_distance) |
|
|
|
+ torch.abs(red_blue_distance - green_blue_distance) |
|
|
|
) |
|
|
|
red_loss = torch.abs(0 - red) |
|
|
|
green_loss = torch.abs(1 / 3 - green) / (2 / 3) |
|
|
|
blue_loss = torch.abs(2 / 3 - blue) / (2 / 3) |
|
|
|
return red_loss + green_loss + blue_loss |
|
|
|
|
|
|
|
|
|
|
|
def calculate_separation_loss(model): |
|
|
|