import lightning as L
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from losses import preservation_loss

# from utils import PURE_RGB


class ColorTransformerModel(L.LightningModule):
    """Map 3-channel (RGB) inputs to a single value in (0, 1) with a small MLP.

    Training mixes an unsupervised preservation loss (``preservation_loss``
    from ``losses``) with a supervised distance to the batch labels:
    ``loss = (1 - alpha) * preservation + alpha * distance``.

    Args:
        transform: Activation used between layers, ``"tanh"`` or ``"relu"``
            (case-insensitive).
        width: Width of the hidden layers.
        depth: Number of hidden ``width -> width`` layers.
        bias: Whether the linear layers use a bias term.
        alpha: Weight of the supervised distance term.
        lr: SGD learning rate.

    Raises:
        ValueError: If ``transform`` is not a supported activation name.
    """

    _ACTIVATIONS = {"tanh": nn.Tanh, "relu": nn.ReLU}

    def __init__(
        self,
        transform: str = "relu",
        width: int = 128,
        depth: int = 1,
        bias: bool = False,
        alpha: float = 0,
        lr: float = 0.01,
    ):
        super().__init__()
        self.save_hyperparameters()

        try:
            t = self._ACTIVATIONS[self.hparams.transform.lower()]
        except KeyError:
            # Previously an unknown name left `t` unbound -> UnboundLocalError.
            raise ValueError(
                f"Unsupported transform {self.hparams.transform!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

        w = self.hparams.width
        d = self.hparams.depth
        bias = self.hparams.bias

        # Create fresh modules for every hidden layer. `[Linear, act] * d`
        # repeats the *same* Linear instance d times, tying all hidden weights.
        # Layer order (and state_dict keys) are unchanged.
        midlayers = []
        for _ in range(d):
            midlayers += [nn.Linear(w, w, bias=bias), t()]

        self.network = nn.Sequential(
            nn.Linear(3, w, bias=bias),
            t(),
            *midlayers,
            nn.Linear(w, 3, bias=bias),
            t(),
            nn.Linear(3, 1, bias=bias),
        )

    def forward(self, x):
        """Run the MLP and squash its single output into (0, 1) with a sigmoid."""
        x = self.network(x)
        # Circular mapping alternative:
        # x = (torch.sin(x) + 1) / 2
        return torch.sigmoid(x)

    @staticmethod
    def _per_sample_distance(outputs, labels):
        """Return the Euclidean distance between outputs and labels per batch row.

        Dim 0 is treated as the batch; all remaining dims are reduced.
        Assumes ``labels`` matches (or broadcasts to) ``outputs`` of shape
        ``(batch, 1)`` — TODO confirm against the datamodule.
        """
        diff = outputs - labels
        return torch.linalg.vector_norm(diff.flatten(start_dim=1), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the alpha-weighted preservation + supervised loss."""
        # inputs are the RGB colors; labels are the supervised targets.
        inputs, labels = batch
        outputs = self.forward(inputs)

        # rgb_tensor = PURE_RGB.to(self.device)
        p_loss = preservation_loss(
            inputs,
            outputs,
            # target_inputs=rgb_tensor,
            # target_outputs=self.forward(rgb_tensor),
        )

        # Mean of per-sample distances. An un-dimensioned torch.norm summed
        # over the whole batch, making the term scale with batch size.
        # N = len(outputs)
        # distance = circle_norm(outputs, labels) / (N*(N-1)/2)
        distance = self._per_sample_distance(outputs, labels).mean()

        alpha = self.hparams.alpha
        # p_loss is unsupervised (preserve relative distances - either
        # in-batch or to-target); distance is supervised.
        loss = (1 - alpha) * p_loss + alpha * distance

        self.log("hp_metric", loss)
        # Log all losses individually. "train_mse" is a historical key name;
        # the value is a mean Euclidean distance, not a squared error.
        self.log("train_pres", p_loss)
        self.log("train_mse", distance)
        return loss

    def validation_step(self, batch):
        """Log the mean and worst-case per-sample distance for a validation batch."""
        inputs, labels = batch  # these are true HSV labels - no learning allowed.
        outputs = self.forward(inputs)
        # Circular alternative:
        # distance = torch.minimum(
        #     torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
        # )
        # Per-sample distances, so that `val_max` really is the worst sample
        # (an un-dimensioned norm made mean == max).
        distance = self._per_sample_distance(outputs, labels)
        mean_loss = torch.mean(distance)
        max_loss = torch.max(distance)
        self.log("val_mse", mean_loss)
        self.log("val_max", max_loss)
        return mean_loss

    def configure_optimizers(self):
        """Use SGD with a plateau-based LR decay keyed on ``hp_metric``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
        )
        # `verbose` is deprecated in recent PyTorch releases. Track the LR with
        # Lightning's LearningRateMonitor callback instead.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.05, patience=5, cooldown=10
        )
        # NOTE(review): hp_metric is logged from training_step with default
        # flags (per step), so the plateau check presumably sees only the
        # latest batch loss. Consider on_epoch=True — confirm.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "hp_metric",  # Specify the metric to monitor
            },
        }