import pytorch_lightning as pl
import torch
import torch.nn as nn

from losses import enhanced_loss, weighted_loss


class ColorTransformerModel(pl.LightningModule):
    """Small fully connected network mapping 3 input features to 3 outputs.

    The network is two 128-unit hidden layers with ReLU activations. Training
    minimizes ``enhanced_loss`` computed from the inputs and the network's
    outputs; the labels delivered with each batch are not used by the loss.

    Args:
        alpha: Forwarded to ``enhanced_loss`` as ``alpha``.
        distinct_threshold: Forwarded to ``enhanced_loss`` as
            ``distinct_threshold``.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        # Records every constructor argument under ``self.hparams`` so the
        # other methods (and checkpoints) can read them back.
        self.save_hyperparameters()

        # Model layers: 3 -> 128 -> 128 -> 3.
        self.layers = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch.

        Args:
            batch: An ``(inputs, labels)`` pair. Per the original author's
                note, inputs are the RGB values and labels are strings; the
                labels are unpacked but not used here.
            batch_idx: Index of the batch (unused).

        Returns:
            The value returned by ``enhanced_loss``.
        """
        inputs, _labels = batch

        # Call the module itself rather than ``self.forward`` so that any
        # registered forward hooks are executed.
        outputs = self(inputs)

        # The loss compares the inputs with the outputs directly.
        # ``weighted_loss(inputs, outputs, alpha=...)`` is the alternative
        # loss imported from the same module.
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters of the model."""
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.learning_rate
        )