Browse Source

sequential instead of loopback

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
05dd4e29ce
  1. 7
      model.py
  2. 13
      newmain.py
  3. 2
      out/index.html

7
model.py

@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import circle_norm, preservation_loss from losses import circle_norm, preservation_loss # noqa: F401
from utils import RGBMYC_ANCHOR from utils import RGBMYC_ANCHOR
@ -28,7 +28,10 @@ class ColorTransformerModel(L.LightningModule):
w = self.hparams.width w = self.hparams.width
d = self.hparams.depth d = self.hparams.depth
bias = self.hparams.bias bias = self.hparams.bias
midlayers = [nn.Linear(w, w, bias=bias), t()] * d midlayers = []
for _ in range(d):
midlayers += [nn.Linear(w, w, bias=bias), t()]
self.network = nn.Sequential( self.network = nn.Sequential(
nn.Linear(3, w, bias=bias), nn.Linear(3, w, bias=bias),
t(), t(),

13
newmain.py

@ -13,16 +13,3 @@ def cli_main():
if __name__ == "__main__": if __name__ == "__main__":
cli_main() cli_main()
# note: it is good practice to implement the CLI in a function and call it in the main if block # note: it is good practice to implement the CLI in a function and call it in the main if block
# save_img_callback = SaveImageCallback(
# save_interval=0,
# final_dir="out",
# )
# trainer = pl.Trainer(
# callbacks=[save_img_callback],
# )
# data = ColorDataModule()
# trainer.fit(model, data)

2
out/index.html

@ -72,7 +72,7 @@
function loadImages() { function loadImages() {
var gallery = document.getElementById('gallery'); var gallery = document.getElementById('gallery');
for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100 for (var i = 0; i < 200; i++) { // Changed from i <= 100 to i < 100
let imageName; let imageName;
if (i == -21) { if (i == -21) {
imageName = 'hsv.png'; imageName = 'hsv.png';

Loading…
Cancel
Save