You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

76 lines
2.2 KiB

import pytorch_lightning as pl
import torch
import torch.nn as nn
from losses import enhanced_loss, weighted_loss
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, alpha, distinct_threshold, learning_rate):
# super().__init__()
# self.save_hyperparameters()
# # Model layers
# self.layers = nn.Sequential(
# nn.Linear(3, 128),
# nn.ReLU(),
# nn.Linear(128, 128),
# nn.ReLU(),
# nn.Linear(128, 1),
# )
# def forward(self, x):
# return self.layers(x)
class ColorTransformerModel(pl.LightningModule):
    """Map each 3-feature input row to a single scalar using a Transformer encoder.

    Each row is embedded into a 128-dim space and run through a 3-layer
    Transformer encoder as a length-1 sequence. It is then projected to one
    output value. Rows are processed independently of one another.

    Args:
        alpha: Forwarded to ``enhanced_loss`` during training.
        distinct_threshold: Forwarded to ``enhanced_loss`` during training.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        # Stores alpha / distinct_threshold / learning_rate in self.hparams,
        # where training_step and configure_optimizers read them back.
        self.save_hyperparameters()

        # Embedding layer: lifts each 3-feature row into the 128-dim model space.
        self.embedding = nn.Linear(3, 128)

        # batch_first=True makes the encoder interpret its input as
        # (batch, seq, d_model), which matches the (batch, 1, 128) tensor built
        # in forward(). With PyTorch's default (batch_first=False), dim 0, the
        # batch, would be treated as the sequence. Samples in a batch would
        # then attend to each other, making predictions depend on batch makeup.
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=128,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3)

        # Final projection from the 128-dim encoding to a single scalar per row.
        self.final_layer = nn.Linear(128, 1)

    def forward(self, x):
        """Predict one scalar per input row.

        Args:
            x: Tensor of shape ``(batch, 3)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x = self.embedding(x)  # (batch, 128)
        # Treat every row as its own length-1 sequence: (batch, seq=1, 128).
        x = x.unsqueeze(1)
        x = self.transformer_encoder(x)
        # Drop the length-1 sequence axis: back to (batch, 128).
        x = x.squeeze(1)
        return self.final_layer(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: ``(inputs, labels)`` pair. Only ``inputs`` is used; the
                loss is computed from the inputs and the model outputs.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss returned by ``enhanced_loss``, also logged as
            ``train_loss``.
        """
        inputs, _labels = batch  # labels are not consumed by the loss
        # Calling the module (rather than self.forward) runs any registered hooks.
        outputs = self(inputs)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using the saved learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)