import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from losses import enhanced_loss, weighted_loss

# class ColorTransformerModel(pl.LightningModule):
#     def __init__(self, alpha, distinct_threshold, learning_rate):
#         super().__init__()
#         self.save_hyperparameters()
#
#         # Model layers
#         self.layers = nn.Sequential(
#             nn.Linear(3, 128),
#             nn.ReLU(),
#             nn.Linear(128, 128),
#             nn.ReLU(),
#             nn.Linear(128, 1),
#         )
#
#     def forward(self, x):
#         return self.layers(x)


class ColorTransformerModel(pl.LightningModule):
    """Map 3-feature (RGB) inputs to a single value in (0, 1).

    The network is a linear embedding (3 -> 128), a 3-layer Transformer
    encoder applied to a length-1 sequence, and a linear head (128 -> 1)
    followed by a sigmoid.

    Args:
        alpha: Weighting factor forwarded to ``enhanced_loss`` during
            training.
        learning_rate: Initial learning rate for the AdamW optimizer.
    """

    def __init__(self, alpha, learning_rate):
        super().__init__()
        # Records ``alpha`` and ``learning_rate`` on ``self.hparams``; both are
        # read back in ``training_step`` and ``configure_optimizers``.
        self.save_hyperparameters()

        # Embedding layer to expand the input dimensions.
        self.embedding = nn.Linear(3, 128)

        # Transformer block. ``batch_first=True`` is required because
        # ``forward`` inserts the sequence axis at dim 1, producing
        # (batch, seq=1, feature). With the default ``batch_first=False`` the
        # encoder would read that tensor as (seq=batch, batch=1, feature) and
        # let every sample attend to the other samples in the mini-batch.
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=128,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transformer_layer, num_layers=3
        )

        # Final linear layer to map back to 1D space.
        self.final_layer = nn.Linear(128, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of inputs.

        Args:
            x: Input tensor whose last dimension is 3. Assumed to be shaped
                (batch, 3) -- TODO confirm against the data module.

        Returns:
            Tensor of sigmoid activations in (0, 1); (batch, 1) for a
            (batch, 3) input.
        """
        # Embedding the input.
        x = self.embedding(x)

        # Add a length-1 sequence axis: (batch, 128) -> (batch, 1, 128).
        x = x.unsqueeze(1)

        # Each sample is its own one-token sequence, so samples are processed
        # independently of each other.
        x = self.transformer_encoder(x)

        # Drop the sequence axis again: (batch, 1, 128) -> (batch, 128).
        x = x.squeeze(1)

        # Final linear layer.
        x = self.final_layer(x)

        # Sigmoid keeps the output in (0, 1).
        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: ``(inputs, labels)`` pair; inputs are the RGB values and
                labels are the strings. The labels are not used by the loss.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The value returned by ``enhanced_loss`` for this batch.
        """
        inputs, labels = batch  # labels are unpacked but not needed here
        # Call the module rather than ``forward`` directly so that any
        # registered module hooks are run.
        outputs = self(inputs)
        # loss = weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
        )
        # Logged under the same key the LR scheduler monitors.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Build the AdamW optimizer and a plateau-based LR scheduler.

        Returns:
            A dict with the optimizer and an ``lr_scheduler`` config whose
            scheduler monitors the ``train_loss`` metric.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-2,
        )
        # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases;
        # confirm the pinned version still accepts it before upgrading.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=10, verbose=True
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'monitor': 'train_loss',  # Specify the metric to monitor
            },
        }