2023-12-30 04:37:06 +00:00
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2023-12-30 05:30:52 +00:00
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
2023-12-30 04:37:06 +00:00
|
|
|
|
2024-01-16 04:37:22 +00:00
|
|
|
from losses import calculate_separation_loss, preservation_loss # noqa: F401
|
|
|
|
from utils import PURE_HSV, PURE_RGB
|
2023-12-30 05:13:50 +00:00
|
|
|
|
2024-01-14 06:04:19 +00:00
|
|
|
# class ColorTransformerModel(pl.LightningModule):
|
|
|
|
# def __init__(self, params):
|
|
|
|
# super().__init__()
|
|
|
|
# self.save_hyperparameters(params)
|
|
|
|
|
|
|
|
# # Model layers
|
|
|
|
# self.layers = nn.Sequential(
|
|
|
|
# nn.Linear(5, 128, bias=False),
|
|
|
|
# nn.Linear(128, 3, bias=False),
|
|
|
|
# nn.ReLU(),
|
|
|
|
# nn.Linear(3, 64, bias=False),
|
|
|
|
# nn.Linear(64, 128, bias=False),
|
|
|
|
# nn.Linear(128, 256, bias=False),
|
|
|
|
# nn.Linear(256, 128, bias=False),
|
|
|
|
# nn.ReLU(),
|
|
|
|
# nn.Linear(128, 1, bias=False),
|
|
|
|
# )
|
|
|
|
|
|
|
|
# def forward(self, x):
|
|
|
|
# x = self.layers(x)
|
|
|
|
# x = (torch.sin(x) + 1) / 2
|
|
|
|
# return x
|
|
|
|
|
|
|
|
# class ColorTransformerModel(pl.LightningModule):
|
|
|
|
# def __init__(self, params):
|
|
|
|
# super().__init__()
|
|
|
|
# self.save_hyperparameters(params)
|
|
|
|
|
|
|
|
# # Embedding layer to expand the input dimensions
|
|
|
|
# self.embedding = nn.Linear(3, 128, bias=False)
|
|
|
|
|
|
|
|
# # Transformer encoder-decoder
|
|
|
|
# encoder = nn.TransformerEncoderLayer(
|
|
|
|
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
|
|
|
|
# )
|
|
|
|
# self.transformer_encoder = nn.TransformerEncoder(
|
|
|
|
# encoder, num_layers=3
|
|
|
|
# )
|
|
|
|
# # lower dimensionality decoder
|
|
|
|
# decoder = nn.TransformerDecoderLayer(
|
|
|
|
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
|
|
|
|
# )
|
|
|
|
# self.transformer_decoder = nn.TransformerDecoder(
|
|
|
|
# decoder, num_layers=3
|
|
|
|
# )
|
|
|
|
|
|
|
|
# # Final linear layer to map back to 1D space
|
|
|
|
# self.final_layer = nn.Linear(128, 1, bias=False)
|
|
|
|
|
|
|
|
# def forward(self, x):
|
|
|
|
# # Embedding the input
|
|
|
|
# x = self.embedding(x)
|
|
|
|
|
|
|
|
# # Adjusting the shape for the transformer
|
|
|
|
# x = x.unsqueeze(1) # Adding a fake sequence dimension
|
|
|
|
|
|
|
|
# # Passing through the transformer
|
|
|
|
# x = self.transformer_encoder(x)
|
|
|
|
|
|
|
|
# # Passing through the decoder
|
|
|
|
# x = self.transformer_decoder(x, memory=x)
|
|
|
|
|
|
|
|
# # Reshape back to original shape
|
|
|
|
# x = x.squeeze(1)
|
|
|
|
|
|
|
|
# # Final linear layer
|
|
|
|
# x = self.final_layer(x)
|
|
|
|
|
|
|
|
# # Apply sigmoid activation to ensure output is in (0, 1)
|
|
|
|
# # x = torch.sigmoid(x)
|
|
|
|
# x = (torch.sin(x) + 1) / 2
|
|
|
|
# return x
|
|
|
|
|
2023-12-30 05:30:52 +00:00
|
|
|
|
2023-12-30 04:37:06 +00:00
|
|
|
class ColorTransformerModel(pl.LightningModule):
    """Map each input vector to a single scalar in (0, 1).

    Hyperparameters are read from ``params`` via ``save_hyperparameters``:
      * ``width``         -- size of the second hidden layer.
      * ``alpha``         -- weight of the anchor term ``s_loss`` in the loss.
      * ``learning_rate`` -- SGD learning rate.
    """

    def __init__(self, params):
        super().__init__()
        self.save_hyperparameters(params)

        # MLP: 5 input features -> 64 -> width -> 3 -> 1 scalar logit.
        self.network = nn.Sequential(
            nn.Linear(5, 64),
            nn.Tanh(),
            nn.Linear(64, self.hparams.width),
            nn.Tanh(),
            nn.Linear(self.hparams.width, 3),
            nn.Tanh(),
            nn.Linear(3, 1),
        )

    def forward(self, x):
        """Return a tensor with last dimension 1 and values in (0, 1).

        ``x`` must have 5 features in its last dimension to match the first
        ``nn.Linear(5, 64)`` layer.
        """
        logits = self.network(x)
        # Sigmoid bounds the output to the open interval (0, 1).
        return torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch.

        The total loss is ``p_loss + alpha * s_loss``. Both terms come from
        ``preservation_loss``:
          * ``p_loss`` is computed from the batch inputs and outputs alone.
          * ``s_loss`` also passes the pure-RGB anchor colors and their
            pure-HSV targets, which acts somewhat like labeled data.
        NOTE(review): exact semantics of ``preservation_loss`` live in
        ``losses``; confirm there.
        """
        inputs, _labels = batch  # labels are not used by this step
        # Call through __call__ (not .forward) so nn.Module hooks run.
        outputs = self(inputs)

        # Anchor term: preserve distance to pure R, G, B.
        # The "s_loss" name is kept from when this term was
        # calculate_separation_loss, so existing log keys stay stable.
        s_loss = preservation_loss(
            inputs,
            outputs,
            target_inputs=PURE_RGB,
            target_outputs=PURE_HSV,
        )
        p_loss = preservation_loss(inputs, outputs)

        loss = p_loss + self.hparams.alpha * s_loss

        # "hp_metric" is also the key monitored by the LR scheduler configured
        # in configure_optimizers; keep the two in sync.
        self.log("hp_metric", loss)
        self.log("p_loss", p_loss)
        self.log("s_loss", s_loss)
        return loss

    def configure_optimizers(self):
        """Return plain SGD plus a ReduceLROnPlateau scheduler on "hp_metric".

        The scheduler multiplies the LR by ``factor=0.05`` once the monitored
        metric has not improved for more than ``patience=5`` scheduler steps.
        It then waits ``cooldown=10`` steps before it can reduce again.
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.05,
            patience=5,
            cooldown=10,
            # NOTE(review): `verbose` is deprecated in recent PyTorch releases
            # and emits a warning; kept so LR reductions are still printed.
            verbose=True,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # Must match the key logged in training_step.
                "monitor": "hp_metric",
            },
        }
|