|
@@ -4,7 +4,6 @@ import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from losses import calculate_separation_loss, preservation_loss  # noqa: F401
-from utils import PURE_HSV, PURE_RGB
 
 
 class ColorTransformerModel(L.LightningModule):
@@ -46,23 +45,14 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # s_loss = calculate_separation_loss(model=self)
-        # preserve distance to pure R, G, B. this acts kind of like labeled data.
-        s_loss = preservation_loss(
-            inputs,
-            outputs,
-            target_inputs=PURE_RGB,
-            target_outputs=PURE_HSV,
-        )
         p_loss = preservation_loss(
             inputs,
             outputs,
         )
-        alpha = self.hparams.alpha
-        loss = p_loss + alpha * s_loss
+        # alpha = self.hparams.alpha  # TODO: decide what to do with this...
+        loss = p_loss
         self.log("hp_metric", loss)
         self.log("p_loss", p_loss)
-        self.log("s_loss", s_loss)
         return loss
 
     def validation_step(self, batch):
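
For context on the term this diff removes: preservation_loss was called twice, once with fixed anchors (target_inputs=PURE_RGB, target_outputs=PURE_HSV) and once on the batch alone, and the patch keeps only the latter. The sketch below shows one plausible shape for such a distance-preservation loss. It is an illustration only, not the repo's losses.preservation_loss; the PURE_RGB/PURE_HSV values and the tensor shapes (RGB inputs of shape (N, 3), hue-like outputs of shape (N, 1)) are assumptions.

import torch
import torch.nn.functional as F

# Hypothetical anchor values, for illustration only; the repo's utils.PURE_RGB
# and utils.PURE_HSV may be defined differently.
PURE_RGB = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
PURE_HSV = torch.tensor([[0.0], [1.0 / 3.0], [2.0 / 3.0]])  # hue of pure R, G, B


def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
    # Compare pairwise distances in input space with distances in output space,
    # either within the batch (no targets) or against fixed anchor points.
    if target_inputs is None:
        target_inputs, target_outputs = inputs, outputs
    d_in = torch.cdist(inputs, target_inputs)     # distances between RGB colors
    d_out = torch.cdist(outputs, target_outputs)  # distances between model outputs
    return F.mse_loss(d_out, d_in)


if __name__ == "__main__":
    rgb = torch.rand(8, 3)  # a batch of RGB colors in [0, 1]
    hue = torch.rand(8, 1)  # stand-in for model outputs
    # Batch-wise term kept by the diff:
    print(preservation_loss(rgb, hue))
    # Anchored term removed by the diff:
    print(preservation_loss(rgb, hue, target_inputs=PURE_RGB, target_outputs=PURE_HSV))

Under this shape, dropping the anchored call removes the only use of PURE_RGB/PURE_HSV, which is consistent with the first hunk deleting the utils import.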
|
|