diff --git a/model.py b/model.py index dd6029d..775fdda 100644 --- a/model.py +++ b/model.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from torch.optim.lr_scheduler import ReduceLROnPlateau -from losses import circle_norm, preservation_loss +from losses import circle_norm, preservation_loss # noqa: F401 from utils import RGBMYC_ANCHOR @@ -28,7 +28,10 @@ class ColorTransformerModel(L.LightningModule): w = self.hparams.width d = self.hparams.depth bias = self.hparams.bias - midlayers = [nn.Linear(w, w, bias=bias), t()] * d + midlayers = [] + for _ in range(d): + midlayers += [nn.Linear(w, w, bias=bias), t()] + self.network = nn.Sequential( nn.Linear(3, w, bias=bias), t(), diff --git a/newmain.py b/newmain.py index eec05a7..dbf4658 100644 --- a/newmain.py +++ b/newmain.py @@ -13,16 +13,3 @@ def cli_main(): if __name__ == "__main__": cli_main() # note: it is good practice to implement the CLI in a function and call it in the main if block - - -# save_img_callback = SaveImageCallback( -# save_interval=0, -# final_dir="out", -# ) - -# trainer = pl.Trainer( -# callbacks=[save_img_callback], -# ) - -# data = ColorDataModule() -# trainer.fit(model, data) diff --git a/out/index.html b/out/index.html index 1209be9..70e1fbd 100644 --- a/out/index.html +++ b/out/index.html @@ -72,7 +72,7 @@ function loadImages() { var gallery = document.getElementById('gallery'); - for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100 + for (var i = 0; i < 200; i++) { // Changed from i <= 100 to i < 100 let imageName; if (i == -21) { imageName = 'hsv.png';