You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

81 lines
2.4 KiB

import lightning as L
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss # noqa: F401
from utils import PURE_HSV, PURE_RGB
class ColorTransformerModel(L.LightningModule):
    """Lightning module mapping 3-component color inputs to a single value.

    The network is an MLP::

        Linear(3, width) -> act
        -> depth x [Linear(width, width) -> act]
        -> Linear(width, 3) -> act
        -> Linear(3, 1)

    and :meth:`forward` squashes its output with a sigmoid.
    """

    # Supported values for the ``transform`` hyperparameter (matched
    # case-insensitively) and the activation module class each selects.
    _ACTIVATIONS = {"tanh": nn.Tanh, "relu": nn.ReLU}

    def __init__(
        self,
        transform: str = "relu",
        width: int = 128,
        depth: int = 1,
        bias: bool = False,
        learning_rate: float = 0.01,
        alpha: float = 1.0,
    ):
        """Build the network and record every argument in ``self.hparams``.

        Args:
            transform: Activation name, ``"tanh"`` or ``"relu"``
                (case-insensitive), used after each hidden ``Linear``.
            width: Hidden layer width.
            depth: Number of extra ``width -> width`` hidden layers; each one
                gets its own independent weights.
            bias: Whether every ``nn.Linear`` layer has a bias term.
            learning_rate: SGD learning rate read by
                :meth:`configure_optimizers`. NOTE(review): default is a
                placeholder; it was not specified previously. Confirm.
            alpha: Weight of ``s_loss`` in the training loss. NOTE(review):
                default is a placeholder; confirm.

        Raises:
            ValueError: If ``transform`` is not a supported activation name.
        """
        super().__init__()
        # Records all __init__ arguments on self.hparams. learning_rate and
        # alpha must be arguments here, because training_step and
        # configure_optimizers read them from self.hparams.
        self.save_hyperparameters()

        try:
            activation = self._ACTIVATIONS[self.hparams.transform.lower()]
        except KeyError:
            raise ValueError(
                f"Unsupported transform {self.hparams.transform!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

        w = self.hparams.width
        bias = self.hparams.bias

        # Create new modules on every iteration. ``[layer, act] * depth``
        # would reuse a single Linear instance, tying its weights across all
        # hidden layers.
        midlayers = []
        for _ in range(self.hparams.depth):
            midlayers += [nn.Linear(w, w, bias=bias), activation()]

        self.network = nn.Sequential(
            nn.Linear(3, w, bias=bias),
            activation(),
            *midlayers,
            nn.Linear(w, 3, bias=bias),
            activation(),
            nn.Linear(3, 1, bias=bias),
        )

    def forward(self, x):
        """Run the MLP and squash its output into [0, 1] with a sigmoid.

        ``x`` must have a trailing dimension of size 3 (the first layer is
        ``Linear(3, width)``). The output has a trailing dimension of size 1.
        """
        # An alternative circular mapping was tried: (torch.sin(x) + 1) / 2
        return torch.sigmoid(self.network(x))

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the training loss for one batch.

        ``batch`` must unpack into ``(inputs, labels)``. Per the original
        author, inputs are RGB values and labels are strings; labels are not
        used here. The loss is ``p_loss + alpha * s_loss``.
        """
        inputs, _labels = batch
        outputs = self(inputs)

        # Anchor term: per the original author's intent, preserve distances
        # to pure R, G, B. This acts somewhat like labeled data.
        s_loss = preservation_loss(
            inputs,
            outputs,
            target_inputs=PURE_RGB,
            target_outputs=PURE_HSV,
        )
        # Distance-preservation term computed within the batch itself.
        p_loss = preservation_loss(inputs, outputs)

        loss = p_loss + self.hparams.alpha * s_loss

        # "hp_metric" is the metric monitored by the LR scheduler below.
        self.log("hp_metric", loss)
        self.log("p_loss", p_loss)
        self.log("s_loss", s_loss)
        return loss

    def configure_optimizers(self):
        """Return SGD plus a ReduceLROnPlateau scheduler monitoring ``hp_metric``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
        )
        # Multiply the LR by 0.05 once hp_metric has not improved for more
        # than 5 consecutive scheduler steps, then wait 10 steps (cooldown)
        # before another reduction is possible.
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "hp_metric",
            },
        }