import pytorch_lightning as pl import torch import torch.nn as nn from torch.optim.lr_scheduler import ReduceLROnPlateau from losses import calculate_separation_loss, preservation_loss # class ColorTransformerModel(pl.LightningModule): # def __init__(self, params): # super().__init__() # self.save_hyperparameters(params) # # Model layers # self.layers = nn.Sequential( # nn.Linear(5, 128, bias=False), # nn.Linear(128, 3, bias=False), # nn.ReLU(), # nn.Linear(3, 64, bias=False), # nn.Linear(64, 128, bias=False), # nn.Linear(128, 256, bias=False), # nn.Linear(256, 128, bias=False), # nn.ReLU(), # nn.Linear(128, 1, bias=False), # ) # def forward(self, x): # x = self.layers(x) # x = (torch.sin(x) + 1) / 2 # return x # class ColorTransformerModel(pl.LightningModule): # def __init__(self, params): # super().__init__() # self.save_hyperparameters(params) # # Embedding layer to expand the input dimensions # self.embedding = nn.Linear(3, 128, bias=False) # # Transformer encoder-decoder # encoder = nn.TransformerEncoderLayer( # d_model=128, nhead=4, dim_feedforward=512, dropout=0.3 # ) # self.transformer_encoder = nn.TransformerEncoder( # encoder, num_layers=3 # ) # # lower dimensionality decoder # decoder = nn.TransformerDecoderLayer( # d_model=128, nhead=4, dim_feedforward=512, dropout=0.3 # ) # self.transformer_decoder = nn.TransformerDecoder( # decoder, num_layers=3 # ) # # Final linear layer to map back to 1D space # self.final_layer = nn.Linear(128, 1, bias=False) # def forward(self, x): # # Embedding the input # x = self.embedding(x) # # Adjusting the shape for the transformer # x = x.unsqueeze(1) # Adding a fake sequence dimension # # Passing through the transformer # x = self.transformer_encoder(x) # # Passing through the decoder # x = self.transformer_decoder(x, memory=x) # # Reshape back to original shape # x = x.squeeze(1) # # Final linear layer # x = self.final_layer(x) # # Apply sigmoid activation to ensure output is in (0, 1) # # x = torch.sigmoid(x) # x = 
#         (torch.sin(x) + 1) / 2
#         return x


class ColorTransformerModel(pl.LightningModule):
    """Map 3-feature color inputs to a single scalar in (0, 1) with a small MLP.

    Training minimizes ``p_loss + alpha * s_loss``, where ``p_loss`` is the
    preservation loss between inputs and model outputs and ``s_loss`` is a
    separation loss computed from the model itself.

    Hyperparameters read from ``params`` (exposed as ``self.hparams``):
        width: number of units in the first hidden layer.
        alpha: weight of the separation loss relative to the preservation loss.
        learning_rate: initial SGD learning rate.
    """

    def __init__(self, params):
        super().__init__()
        # Stores ``params`` as ``self.hparams`` so the values below (and in the
        # other methods) are read from one place and saved with checkpoints.
        self.save_hyperparameters(params)

        # 3 input features -> width -> 64 -> 1 output unit.
        self.network = nn.Sequential(
            nn.Linear(3, self.hparams.width),
            nn.ReLU(),
            nn.Linear(self.hparams.width, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Run the MLP on ``x`` and squash the result into (0, 1).

        Args:
            x: tensor whose last dimension has size 3 (presumably RGB
                values — confirm against the data pipeline).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a final
            dimension of size 1, passed through a sigmoid.
        """
        logits = self.network(x)
        # Sigmoid bounds the output to (0, 1); the periodic (sin(x) + 1) / 2
        # mapping that used to be here is no longer used.
        return torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss for one training batch.

        Args:
            batch: ``(inputs, labels)`` pair; only ``inputs`` is used.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            ``p_loss + alpha * s_loss``, handed back to Lightning for
            backpropagation.
        """
        # Labels are unpacked only to split the pair; neither loss uses them.
        inputs, _labels = batch
        # Call the module rather than ``self.forward`` so that any registered
        # forward / pre-forward hooks are executed.
        outputs = self(inputs)

        s_loss = calculate_separation_loss(model=self)
        p_loss = preservation_loss(inputs, outputs)
        loss = p_loss + self.hparams.alpha * s_loss

        # "hp_metric" is also the metric the LR scheduler monitors
        # (see configure_optimizers).
        self.log("hp_metric", loss)
        self.log("p_loss", p_loss)
        self.log("s_loss", s_loss)
        return loss

    def configure_optimizers(self):
        """Build an SGD optimizer plus a ReduceLROnPlateau scheduler.

        Returns:
            Lightning optimizer-config dict: the optimizer and an
            ``lr_scheduler`` entry whose ``monitor`` names the logged metric
            ("hp_metric") that the plateau scheduler watches.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
        )
        # mode="min": lower "hp_metric" is better. After more than 10
        # non-improving scheduler steps the LR is multiplied by 0.05, then the
        # scheduler waits 20 steps (cooldown) before it can reduce again.
        # NOTE(review): recent PyTorch releases deprecate ``verbose`` on LR
        # schedulers — confirm against the installed version.
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.05,
            patience=10,
            cooldown=20,
            verbose=True,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "hp_metric",  # Metric logged in training_step
            },
        }