Browse Source

working callback

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
c5e8a43116
  1. 17
      callbacks.py
  2. 2
      check.py
  3. 8
      main.py
  4. 6
      makefile
  5. 2
      out/index.html

17
callbacks.py

@ -6,18 +6,22 @@ from check import create_circle
class SaveImageCallback(pl.Callback):
def __init__(self, save_interval=1):
def __init__(self, save_interval=1, final_dir: str = None):
self.save_interval = save_interval
self.final_dir = final_dir
def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
if self.save_interval <= 0:
return None
if epoch % self.save_interval == 0:
# Set the model to eval mode for generating the image
pl_module.eval()
# Save the image
# if pl_module.trainer.logger:
# version = pl_module.trainer.logger.version
#
# else:
# version = 0
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}")
@ -25,3 +29,10 @@ class SaveImageCallback(pl.Callback):
# Make sure to set it back to train mode
pl_module.train()
def on_train_end(self, trainer, pl_module):
if self.final_dir:
version = pl_module.trainer.logger.version
fname = Path(f"{self.final_dir}") / Path(f"v{version}")
pl_module.eval()
create_circle(pl_module, fname=fname)

2
check.py

@ -70,7 +70,7 @@ def create_circle(ckpt: str, fname: str):
ax.set_yticks([])
ax.axis("off")
fig.tight_layout()
plt.savefig(f"{fname}.png", dpi=300)
plt.savefig(f"{fname}.png", dpi=150)
plt.close()

8
main.py

@ -3,6 +3,7 @@ import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from callbacks import SaveImageCallback
from dataloader import create_named_dataloader
from model import ColorTransformerModel
@ -51,6 +52,11 @@ if __name__ == "__main__":
verbose=True,
)
save_img_callback = SaveImageCallback(
save_interval=0,
final_dir="out",
)
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_named_dataloader(
@ -71,7 +77,7 @@ if __name__ == "__main__":
# Initialize trainer with parsed arguments
trainer = pl.Trainer(
callbacks=[early_stop_callback],
callbacks=[early_stop_callback, save_img_callback],
max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps,
)

6
makefile

@ -8,3 +8,9 @@ test:
search:
python search.py
animate:
ffmpeg -i lightning_logs/version_258/e%04d.png \
-c:v libx264 \
-vf "fps=12,format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" \
~/animated.mp4

2
out/index.html

@ -72,7 +72,7 @@
function loadImages() {
var gallery = document.getElementById('gallery');
for (var i = 175; i < 275; i++) { // Changed from i <= 100 to i < 100
for (var i = 1; i <= 100; i++) { // Changed from i <= 100 to i < 100
let imageName = 'v' + i + '.png';
let img = document.createElement('img');
img.src = imageName;

Loading…
Cancel
Save