You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

41 lines
1.2 KiB

11 months ago
import pytorch_lightning as pl
import torch
import torch.nn as nn
from losses import enhanced_loss, weighted_loss
class ColorTransformerModel(pl.LightningModule):
    """Small MLP mapping a 3-component input vector to a 3-component output.

    Training uses ``enhanced_loss`` from the project-local ``losses`` module,
    which receives both the inputs and the model outputs. Its exact semantics
    are defined in that module, not here.

    Args:
        alpha: Passed through unchanged to ``enhanced_loss`` as ``alpha``.
        distinct_threshold: Passed through unchanged to ``enhanced_loss`` as
            ``distinct_threshold``.
        learning_rate: Learning rate for the Adam optimizer.
        hidden_dim: Width of both hidden layers. Defaults to 128, the size
            that was previously hard-coded.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate, hidden_dim=128):
        super().__init__()
        # Stores every __init__ argument under self.hparams (and in
        # checkpoints); the other methods read their settings from there.
        self.save_hyperparameters()

        # 3 -> hidden -> hidden -> 3 with ReLU between layers; the final
        # layer is purely linear (no output activation).
        self.layers = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``; its last dimension must have size 3."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the training loss for one batch.

        ``batch`` must unpack into exactly two items: the model inputs and
        labels. The labels are not used by the loss; the original code
        described them as strings. ``batch_idx`` is required by Lightning's
        hook signature but is unused.
        """
        inputs, _labels = batch
        # Call the module instead of self.forward() so that nn.Module.__call__
        # runs any registered forward hooks, as Lightning recommends.
        outputs = self(inputs)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at the configured rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)