|
|
@ -3,7 +3,8 @@ import torch |
|
|
|
import torch.nn as nn |
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
|
|
|
|
|
from losses import calculate_separation_loss, preservation_loss |
|
|
|
from losses import calculate_separation_loss, preservation_loss # noqa: F401 |
|
|
|
from utils import PURE_HSV, PURE_RGB |
|
|
|
|
|
|
|
# class ColorTransformerModel(pl.LightningModule): |
|
|
|
# def __init__(self, params): |
|
|
@ -83,18 +84,40 @@ class ColorTransformerModel(pl.LightningModule): |
|
|
|
def __init__(self, params): |
|
|
|
super().__init__() |
|
|
|
self.save_hyperparameters(params) |
|
|
|
|
|
|
|
# self.a = nn.Sequential( |
|
|
|
# nn.Linear(3, 3, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# nn.Linear(3, 3, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# nn.Linear(3, 1, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# ) |
|
|
|
# self.b = nn.Sequential( |
|
|
|
# nn.Linear(3, 3, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# nn.Linear(3, 3, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# nn.Linear(3, 1, bias=False), |
|
|
|
# nn.ReLU(), |
|
|
|
# ) |
|
|
|
# Neural network layers |
|
|
|
self.network = nn.Sequential( |
|
|
|
nn.Linear(3, self.hparams.width), |
|
|
|
nn.ReLU(), |
|
|
|
nn.Linear(self.hparams.width, 64), |
|
|
|
nn.ReLU(), |
|
|
|
nn.Linear(64, 1), |
|
|
|
nn.Linear(5, 64), |
|
|
|
nn.Tanh(), |
|
|
|
nn.Linear(64, self.hparams.width), |
|
|
|
nn.Tanh(), |
|
|
|
nn.Linear(self.hparams.width, 3), |
|
|
|
nn.Tanh(), |
|
|
|
nn.Linear(3, 1), |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
# Pass the input through the network |
|
|
|
# a = self.a(x) |
|
|
|
# b = self.b(x) |
|
|
|
# a = torch.sigmoid(a) |
|
|
|
# b = torch.sigmoid(b) |
|
|
|
# x = torch.cat([x, a, b], dim=-1) |
|
|
|
x = self.network(x) |
|
|
|
# Circular mapping |
|
|
|
# x = (torch.sin(x) + 1) / 2 |
|
|
@ -104,7 +127,14 @@ class ColorTransformerModel(pl.LightningModule): |
|
|
|
def training_step(self, batch, batch_idx): |
|
|
|
inputs, labels = batch # x are the RGB inputs, labels are the strings |
|
|
|
outputs = self.forward(inputs) |
|
|
|
s_loss = calculate_separation_loss(model=self) |
|
|
|
# s_loss = calculate_separation_loss(model=self) |
|
|
|
# preserve distance to pure R, G, B. this acts kind of like labeled data. |
|
|
|
s_loss = preservation_loss( |
|
|
|
inputs, |
|
|
|
outputs, |
|
|
|
target_inputs=PURE_RGB, |
|
|
|
target_outputs=PURE_HSV, |
|
|
|
) |
|
|
|
p_loss = preservation_loss( |
|
|
|
inputs, |
|
|
|
outputs, |
|
|
|