|
@ -3,7 +3,7 @@ import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
|
|
|
|
|
|
from losses import circle_norm, preservation_loss |
|
|
from losses import circle_norm, preservation_loss # noqa: F401 |
|
|
from utils import RGBMYC_ANCHOR |
|
|
from utils import RGBMYC_ANCHOR |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -28,7 +28,10 @@ class ColorTransformerModel(L.LightningModule): |
|
|
w = self.hparams.width |
|
|
w = self.hparams.width |
|
|
d = self.hparams.depth |
|
|
d = self.hparams.depth |
|
|
bias = self.hparams.bias |
|
|
bias = self.hparams.bias |
|
|
midlayers = [nn.Linear(w, w, bias=bias), t()] * d |
|
|
midlayers = [] |
|
|
|
|
|
for _ in range(d): |
|
|
|
|
|
midlayers += [nn.Linear(w, w, bias=bias), t()] |
|
|
|
|
|
|
|
|
self.network = nn.Sequential( |
|
|
self.network = nn.Sequential( |
|
|
nn.Linear(3, w, bias=bias), |
|
|
nn.Linear(3, w, bias=bias), |
|
|
t(), |
|
|
t(), |
|
|