|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
|
|
|
|
from losses import enhanced_loss, weighted_loss
|
|
|
|
|
|
|
|
# class ColorTransformerModel(pl.LightningModule):
|
|
|
|
# def __init__(self, alpha, distinct_threshold, learning_rate):
|
|
|
|
# super().__init__()
|
|
|
|
# self.save_hyperparameters()
|
|
|
|
|
|
|
|
# # Model layers
|
|
|
|
# self.layers = nn.Sequential(
|
|
|
|
# nn.Linear(3, 128),
|
|
|
|
# nn.ReLU(),
|
|
|
|
# nn.Linear(128, 128),
|
|
|
|
# nn.ReLU(),
|
|
|
|
# nn.Linear(128, 1),
|
|
|
|
# )
|
|
|
|
|
|
|
|
# def forward(self, x):
|
|
|
|
# return self.layers(x)
|
|
|
|
|
|
|
|
|
|
|
|
class ColorTransformerModel(pl.LightningModule):
    """Map an RGB colour to a single scalar in (0, 1) with a small Transformer encoder.

    Each sample is embedded to 128 features, treated as a one-token sequence,
    passed through a 3-layer ``nn.TransformerEncoder``, projected to one value
    and squashed with a sigmoid.

    Args:
        alpha: Hyperparameter forwarded to ``enhanced_loss`` in ``training_step``.
        learning_rate: Initial learning rate for the AdamW optimizer.
    """

    def __init__(self, alpha, learning_rate):
        super().__init__()
        # Records ``alpha`` and ``learning_rate`` under ``self.hparams``.
        self.save_hyperparameters()

        # Embedding layer expanding the 3 input channels to the model width.
        self.embedding = nn.Linear(3, 128)

        # Transformer block. ``batch_first=True`` makes the encoder interpret its
        # input as (batch, seq, feature), matching the (batch, 1, 128) tensor that
        # ``forward`` builds. With PyTorch's default (batch_first=False) that tensor
        # was read as (seq=batch, batch=1, feature): every colour attended to every
        # other colour in the mini-batch, so a sample's output depended on its
        # batch-mates. ``batch_first`` adds/renames no parameters, so existing
        # checkpoints still load, but their predictions will differ from before.
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=128,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3)

        # Final linear layer mapping each sample back to a single value.
        self.final_layer = nn.Linear(128, 1)

    def forward(self, x):
        """Score a batch of colours.

        Assumes ``x`` is a (batch, 3) RGB tensor (the embedding takes 3 input
        features and dim 1 is used as the sequence axis) — TODO confirm against
        the data module.

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        x = self.embedding(x)            # (batch, 128)
        x = x.unsqueeze(1)               # (batch, 1, 128): one-token sequence per sample
        x = self.transformer_encoder(x)  # per-sample encoding (batch_first=True)
        x = x.squeeze(1)                 # back to (batch, 128)
        x = self.final_layer(x)          # (batch, 1)
        # Sigmoid keeps the output strictly inside (0, 1).
        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx):
        """Compute and log ``enhanced_loss`` for one training batch.

        ``batch`` is unpacked as ``(inputs, labels)``; only ``inputs`` (the RGB
        values) are used. The second element (label strings, per the original
        comment) is ignored by this objective.
        """
        inputs, _labels = batch
        # Call the module rather than ``.forward`` so registered hooks run.
        outputs = self(inputs)
        # Alternative objective kept for experimentation:
        #     weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
        loss = enhanced_loss(inputs, outputs, alpha=self.hparams.alpha)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return AdamW plus a ReduceLROnPlateau scheduler driven by ``train_loss``.

        The LR is multiplied by 0.1 after 10 scheduler steps without improvement.
        ``verbose=True`` was removed: it only printed LR changes and is deprecated
        in PyTorch >= 2.2 (use Lightning's ``LearningRateMonitor`` callback instead).
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-2,
        )
        lr_scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # NOTE(review): ``train_loss`` is logged from training_step with
                # Lightning's defaults (step-level), so the scheduler likely sees
                # only the last batch's loss each epoch — consider logging it with
                # ``on_epoch=True`` for a less noisy plateau signal.
                "monitor": "train_loss",
            },
        }
|