You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

117 lines
3.5 KiB

import lightning as L
11 months ago
import torch
import torch.nn as nn
11 months ago
from torch.optim.lr_scheduler import ReduceLROnPlateau
11 months ago
from losses import circle_norm, preservation_loss # noqa: F401
from utils import RGBMYC_ANCHOR
11 months ago
class ColorTransformerModel(L.LightningModule):
    """MLP mapping a 3-channel (RGB) input to a single value in (0, 1).

    Training minimises a blend of an unsupervised ``preservation_loss`` term
    and a supervised distance to ``labels``, weighted by ``alpha``.
    """

    def __init__(
        self,
        transform: str = "relu",
        width: int = 128,
        depth: int = 1,
        bias: bool = False,
        alpha: float = 0,
        lr: float = 0.01,
        loop: bool = False,
        dropout=0.5,
    ):
        """Build the network.

        Args:
            transform: Activation name, ``"relu"`` or ``"tanh"`` (case-insensitive).
            width: Width of the hidden layers.
            depth: Number of hidden ``width x width`` blocks.
            bias: Whether every linear layer has a bias term.
            alpha: Weight of the supervised distance in the training loss;
                the preservation loss is weighted by ``1 - alpha``.
            lr: SGD learning rate.
            loop: If True, the hidden blocks reuse ONE shared Linear layer
                ``depth`` times (weight tying) and have no dropout. Otherwise
                each block has its own Linear layer followed by Dropout.
            dropout: Dropout probability used in the non-loop hidden blocks.

        Raises:
            ValueError: If ``transform`` is not a supported activation name.
        """
        super().__init__()
        self.save_hyperparameters()
        activations = {"tanh": nn.Tanh, "relu": nn.ReLU}
        key = self.hparams.transform.lower()
        if key not in activations:
            # Previously fell through and left `t` unbound (UnboundLocalError).
            raise ValueError(
                f"Unsupported transform {self.hparams.transform!r}; "
                f"expected one of {sorted(activations)}"
            )
        t = activations[key]
        w = self.hparams.width
        d = self.hparams.depth
        bias = self.hparams.bias
        if self.hparams.loop:
            # List repetition repeats the SAME module objects, so all d blocks
            # share a single Linear layer's weights.
            midlayers = [nn.Linear(w, w, bias=bias), t()] * d
        else:
            midlayers = []
            for _ in range(d):
                midlayers += [
                    nn.Linear(w, w, bias=bias),
                    # Hyperparameters live on self.hparams; `self.dropout` is
                    # never assigned and raised AttributeError here.
                    nn.Dropout(self.hparams.dropout),
                    t(),
                ]
        self.network = nn.Sequential(
            nn.Linear(3, w, bias=bias),
            t(),
            *midlayers,
            nn.Linear(w, 3, bias=bias),
            t(),
            nn.Linear(3, 1, bias=bias),
        )

    def forward(self, x):
        """Run the MLP and squash its output into (0, 1) with a sigmoid.

        ``x`` is expected to have a last dimension of size 3. The output has a
        last dimension of size 1.
        """
        x = self.network(x)
        # A circular alternative, (sin(x) + 1) / 2, was tried previously.
        return torch.sigmoid(x)

    @staticmethod
    def _sample_distance(outputs, labels):
        """Return the Euclidean distance per sample along the last dimension.

        NOTE(review): assumes ``labels`` has the same shape as ``outputs``
        (e.g. ``(batch, 1)``). A ``(batch,)`` label tensor would broadcast
        against ``(batch, 1)`` outputs into ``(batch, batch)``; confirm
        against the dataloader.
        """
        return torch.norm(outputs - labels, dim=-1)

    def training_step(self, batch, batch_idx):
        """Compute and log the blended training loss.

        Returns:
            ``(1 - alpha) * preservation_loss + alpha * mean_distance``.
        """
        inputs, labels = batch  # inputs are RGB values; labels are supervised targets
        outputs = self.forward(inputs)
        rgb_tensor = RGBMYC_ANCHOR.to(self.device)  # noqa: F841
        p_loss = preservation_loss(
            inputs,
            outputs,
            # target_inputs=rgb_tensor,
            # target_outputs=self.forward(rgb_tensor),
        )
        alpha = self.hparams.alpha
        # distance = circle_norm(outputs, labels).mean()
        # Mean of per-sample distances. torch.norm without `dim` used to
        # collapse the whole batch, which made `.mean()` a no-op.
        distance = self._sample_distance(outputs, labels).mean()
        # p_loss is unsupervised (preserve relative distances, either in-batch
        # or to-target); distance is supervised.
        loss = (1 - alpha) * p_loss + alpha * distance
        self.log("hp_metric", distance)
        # Log all losses individually.
        self.log("train_pres", p_loss)
        self.log("train_mse", distance)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Log the mean and worst-case per-sample distance on a validation batch.

        Returns:
            The mean per-sample distance.
        """
        inputs, labels = batch  # these are true HSV labels - no learning allowed.
        outputs = self.forward(inputs)
        # distance = torch.minimum(
        #     torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
        # )
        # Per-sample distances, so that max is the worst sample rather than
        # the same batch-wide scalar as the mean.
        distance = self._sample_distance(outputs, labels)
        mean_loss = torch.mean(distance)
        max_loss = torch.max(distance)
        self.log("val_mse", mean_loss)
        self.log("val_max", max_loss)
        return mean_loss

    def configure_optimizers(self):
        """Use SGD and reduce the learning rate when the logged ``hp_metric`` plateaus."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
        )
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "hp_metric",  # Specify the metric to monitor
            },
        }