|
|
|
import lightning as L
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
|
|
|
|
from losses import calculate_separation_loss, preservation_loss # noqa: F401
|
|
|
|
|
|
|
|
|
|
|
|
class ColorTransformerModel(L.LightningModule):
    """Lightning module mapping RGB triples to one value squashed into (0, 1).

    The network is an MLP ``3 -> width -> (width x depth) -> 3 -> 1`` with the
    chosen activation between layers; :meth:`forward` applies a sigmoid to the
    result. Validation scores outputs with a wrap-around distance on a circle
    of circumference 1 — NOTE(review): this suggests the target is a hue in
    [0, 1); confirm against the data module.
    """

    def __init__(
        self,
        transform: str = "relu",
        width: int = 128,
        depth: int = 1,
        bias: bool = False,
        alpha: float = 0,
    ):
        """Build the network.

        Args:
            transform: Activation name, ``"relu"`` or ``"tanh"`` (case-insensitive).
            width: Number of units in each hidden layer.
            depth: Number of extra ``width -> width`` hidden layers.
            bias: Whether every ``nn.Linear`` layer has a bias term.
            alpha: Saved in ``hparams`` but not used by this class yet.

        Raises:
            ValueError: If ``transform`` is not a supported activation name.
        """
        super().__init__()
        self.save_hyperparameters()

        activations = {"tanh": nn.Tanh, "relu": nn.ReLU}
        name = self.hparams.transform.lower()
        if name not in activations:
            # Previously fell through and crashed with UnboundLocalError on `t`.
            raise ValueError(
                f"Unsupported transform {self.hparams.transform!r}; "
                f"expected one of {sorted(activations)}"
            )
        t = activations[name]

        w = self.hparams.width
        d = self.hparams.depth
        bias = self.hparams.bias

        # Build fresh modules for every repetition: `[Linear, act] * d` would
        # insert the *same* Linear d times and tie all hidden weights together.
        midlayers = []
        for _ in range(d):
            midlayers += [nn.Linear(w, w, bias=bias), t()]

        self.network = nn.Sequential(
            nn.Linear(3, w, bias=bias),
            t(),
            *midlayers,
            nn.Linear(w, 3, bias=bias),
            t(),
            nn.Linear(3, 1, bias=bias),
        )

    def forward(self, x):
        """Run the MLP and squash its output into (0, 1) with a sigmoid.

        Assumes ``x`` has a trailing dimension of size 3 (the first layer is
        ``nn.Linear(3, ...)``); the result has a trailing dimension of size 1.
        """
        x = self.network(x)
        # Sigmoid keeps outputs in (0, 1); a periodic alternative that was
        # tried is (torch.sin(x) + 1) / 2.
        return torch.sigmoid(x)

    @staticmethod
    def _match_label_shape(outputs, labels):
        """Drop the trailing singleton dim of ``outputs`` when ``labels`` lacks it.

        The final ``nn.Linear(3, 1)`` yields shape ``(batch, 1)``. Subtracting
        1-D labels of shape ``(batch,)`` from that would broadcast to
        ``(batch, batch)`` and compare every output with every label.
        """
        if outputs.dim() == labels.dim() + 1 and outputs.shape[-1] == 1:
            return outputs.squeeze(-1)
        return outputs

    @staticmethod
    def _circular_distance(outputs, labels):
        """Element-wise distance on a circle of circumference 1, in [0, 0.5].

        Wraps in both directions, e.g. 0.9 vs 0.1 and 0.1 vs 0.9 both give 0.2.
        """
        diff = torch.remainder(outputs - labels, 1.0)
        return torch.minimum(diff, 1.0 - diff)

    def training_step(self, batch, batch_idx):
        """Compute and log training metrics; return the loss to optimize.

        The returned loss is the Euclidean (Frobenius) norm of the whole
        ``outputs - labels`` error tensor for the batch. ``preservation_loss``
        is logged as ``p_loss`` and ``hp_metric`` but does not contribute to
        the gradient.
        """
        # Labels must be numeric tensors: they are subtracted from outputs.
        inputs, labels = batch
        outputs = self(inputs)
        p_loss = preservation_loss(
            inputs,
            outputs,
        )

        # TODO: decide how self.hparams.alpha should weight p_loss vs distance.
        # NOTE(review): a batch-wide norm scales with sqrt(batch size); the
        # per-sample circular distance used in validation may be a better fit.
        distance = torch.norm(self._match_label_shape(outputs, labels) - labels)

        self.log("train_loss", distance)
        # NOTE(review): the LR scheduler monitors hp_metric, i.e. p_loss, which
        # is not the quantity being optimized.
        self.log("hp_metric", p_loss)
        self.log("p_loss", p_loss)
        return distance

    def validation_step(self, batch, batch_idx=None):
        """Log mean and max per-sample circular distance against true labels.

        Returns:
            The mean per-sample circular distance for the batch.
        """
        inputs, labels = batch  # these are true HSV labels - no learning allowed.
        outputs = self._match_label_shape(self(inputs), labels)

        # Keep the distance per sample so that mean and max are distinct
        # statistics (a batch-wide norm made them identical).
        distance = self._circular_distance(outputs, labels)
        mean_loss = distance.mean()
        max_loss = distance.max()

        self.log("val_mean_loss", mean_loss)
        self.log("val_max_loss", max_loss)
        return mean_loss

    def configure_optimizers(self):
        """Use SGD (lr=0.1) with ReduceLROnPlateau monitoring ``hp_metric``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=0.1,
        )
        # NOTE(review): `verbose` is deprecated for LR schedulers in recent
        # PyTorch releases; consider Lightning's LearningRateMonitor instead.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "hp_metric",  # Specify the metric to monitor
            },
        }
|