diff --git a/losses.py b/losses.py index a72b7f5..2587538 100644 --- a/losses.py +++ b/losses.py @@ -46,21 +46,10 @@ def preservation_loss(inputs, outputs): def separation_loss(red, green, blue): # Separation Component # Encourages the model to keep R, G, B values equally separated in the transformed space - red, green, blue = red % 1, green % 1, blue % 1 - red_green_distance = torch.min( - torch.abs((red - green)), torch.abs((1 + red - green)) - ) - red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue))) - green_blue_distance = torch.min( - torch.abs((green - blue)), torch.abs((1 + green - blue)) - ) - # print(red_green_distance, red_blue_distance, green_blue_distance) - # we want these distances to be equal to one another - return ( - torch.abs(red_green_distance - red_blue_distance) - + torch.abs(red_green_distance - green_blue_distance) - + torch.abs(red_blue_distance - green_blue_distance) - ) + red_loss = torch.abs(0 - red) + green_loss = torch.abs(1 / 3 - green) / (2 / 3) + blue_loss = torch.abs(2 / 3 - blue) / (2 / 3) + return red_loss + green_loss + blue_loss def calculate_separation_loss(model):