You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.2 KiB
40 lines
1.2 KiB
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from losses import enhanced_loss, weighted_loss
|
|
|
|
|
|
class ColorTransformerModel(pl.LightningModule):
    """LightningModule wrapping a small MLP that maps 3 input features to 3 outputs.

    The network is Linear(3, 128) -> ReLU -> Linear(128, 128) -> ReLU ->
    Linear(128, 3). The final layer has no activation, so the model itself
    does not constrain the range of its outputs. Training minimizes
    ``enhanced_loss`` computed from the batch inputs and the model outputs;
    the batch labels are not used.

    Args:
        alpha: Forwarded to ``enhanced_loss`` as ``alpha``.
        distinct_threshold: Forwarded to ``enhanced_loss`` as
            ``distinct_threshold``.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        # Stores the constructor arguments on self.hparams; the training
        # step and optimizer setup below read them from there.
        self.save_hyperparameters()

        # Model layers
        self.layers = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``.

        ``x`` must have a trailing dimension of size 3 (the first layer's
        ``in_features``); the result has a trailing dimension of size 3.
        """
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the training loss for one batch.

        Args:
            batch: A 2-item ``(inputs, labels)`` sequence. Presumably the
                inputs are RGB values and the labels are color-name strings
                (TODO confirm against the data module); labels are unused.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The value returned by ``enhanced_loss``, which is also logged
            under ``"train_loss"``.
        """
        inputs, _labels = batch
        # Call the module instead of self.forward() so nn.Module hooks run.
        outputs = self(inputs)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters using ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
|
|