You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

96 lines
2.9 KiB

11 months ago
import pytorch_lightning as pl
import torch
import torch.nn as nn
11 months ago
from torch.optim.lr_scheduler import ReduceLROnPlateau
11 months ago
11 months ago
from losses import preservation_loss, smoothness_loss
11 months ago
11 months ago
11 months ago
class ColorTransformerModel(pl.LightningModule):
    """Small MLP that maps 5-feature inputs to a single value in [0, 1].

    The network is a three-layer perceptron (5 -> 128 -> 128 -> 1) whose raw
    output is squashed with ``(sin(x) + 1) / 2``. Training minimizes
    ``preservation_loss(inputs, outputs) + alpha * smoothness_loss(outputs)``.

    Args:
        params: Hyperparameters handed to ``save_hyperparameters``. The model
            later reads ``alpha`` (weight of the smoothness term) and
            ``learning_rate`` (AdamW learning rate) from ``self.hparams``.
    """

    def __init__(self, params):
        super().__init__()
        self.save_hyperparameters(params)

        # Model layers
        self.layers = nn.Sequential(
            nn.Linear(5, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Run the MLP and squash its output into [0, 1].

        Args:
            x: Tensor whose last dimension holds the 5 input features.

        Returns:
            Tensor with a last dimension of size 1 and values in [0, 1].
        """
        x = self.layers(x)
        # sin(x) lies in [-1, 1]; shift and scale it into [0, 1].
        return (torch.sin(x) + 1) / 2

    # Disabled transformer-based variant of this model, kept for reference.
    # Note that it takes 3 input features and an (alpha, learning_rate)
    # constructor, unlike the active MLP above.
    #
    # class ColorTransformerModel(pl.LightningModule):
    #     def __init__(self, alpha, learning_rate):
    #         super().__init__()
    #         self.save_hyperparameters()
    #
    #         # Embedding layer to expand the input dimensions
    #         self.embedding = nn.Linear(3, 128)
    #
    #         # Transformer block
    #         transformer_layer = nn.TransformerEncoderLayer(
    #             d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
    #         )
    #         self.transformer_encoder = nn.TransformerEncoder(
    #             transformer_layer, num_layers=3
    #         )
    #
    #         # Final linear layer to map back to 1D space
    #         self.final_layer = nn.Linear(128, 1)
    #
    #     def forward(self, x):
    #         # Embedding the input
    #         x = self.embedding(x)
    #
    #         # Adjusting the shape for the transformer
    #         x = x.unsqueeze(1)  # Adding a fake sequence dimension
    #         # Passing through the transformer
    #         x = self.transformer_encoder(x)
    #         # Reshape back to original shape
    #         x = x.squeeze(1)
    #         # Final linear layer
    #         x = self.final_layer(x)
    #         # Apply sigmoid activation to ensure output is in (0, 1)
    #         x = torch.sigmoid(x)
    #         return x

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss for one training batch.

        Args:
            batch: Pair ``(inputs, labels)``; only ``inputs`` is used here.
            batch_idx: Index of the batch (unused).

        Returns:
            ``preservation_loss + alpha * smoothness_loss`` for this batch.
        """
        # The second element (labels) does not take part in either loss.
        # NOTE(review): an earlier comment described the inputs as RGB, but
        # the first layer expects 5 features -- confirm the feature layout.
        inputs, _ = batch

        # Call the module itself rather than .forward() so that any
        # registered forward hooks are executed.
        outputs = self(inputs)

        s_loss = smoothness_loss(outputs)
        p_loss = preservation_loss(inputs, outputs)

        loss = p_loss + self.hparams.alpha * s_loss

        # "hp_metric" tracks only the preservation term, while "train_loss"
        # is the full objective and is what the LR scheduler monitors.
        self.log("hp_metric", p_loss)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Build the AdamW optimizer and a plateau-based LR scheduler.

        Returns:
            Dict with the optimizer and a scheduler config that monitors the
            logged ``"train_loss"`` metric.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-2,
        )
        # Cut the learning rate by 10x after 10 scheduler steps without
        # improvement in the monitored metric.
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases;
        # confirm it is still accepted by the pinned torch version.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=10, verbose=True
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "train_loss",  # Specify the metric to monitor
            },
        }