diff --git a/callbacks.py b/callbacks.py index 4f25b56..e3158c2 100644 --- a/callbacks.py +++ b/callbacks.py @@ -6,18 +6,22 @@ from check import create_circle class SaveImageCallback(pl.Callback): - def __init__(self, save_interval=1): + def __init__(self, save_interval=1, final_dir: str = None): self.save_interval = save_interval + self.final_dir = final_dir - def on_train_epoch_end(self, trainer, pl_module, outputs): + def on_train_epoch_end(self, trainer, pl_module): epoch = trainer.current_epoch + if self.save_interval <= 0: + return None + if epoch % self.save_interval == 0: # Set the model to eval mode for generating the image pl_module.eval() # Save the image # if pl_module.trainer.logger: - # version = pl_module.trainer.logger.version + # # else: # version = 0 fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}") @@ -25,3 +29,10 @@ class SaveImageCallback(pl.Callback): # Make sure to set it back to train mode pl_module.train() + + def on_train_end(self, trainer, pl_module): + if self.final_dir: + version = pl_module.trainer.logger.version + fname = Path(f"{self.final_dir}") / Path(f"v{version}") + pl_module.eval() + create_circle(pl_module, fname=fname) diff --git a/check.py b/check.py index 9e3a44b..b70ed8f 100644 --- a/check.py +++ b/check.py @@ -70,7 +70,7 @@ def create_circle(ckpt: str, fname: str): ax.set_yticks([]) ax.axis("off") fig.tight_layout() - plt.savefig(f"{fname}.png", dpi=300) + plt.savefig(f"{fname}.png", dpi=150) plt.close() diff --git a/main.py b/main.py index 0161274..e2338d9 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import argparse import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping +from callbacks import SaveImageCallback from dataloader import create_named_dataloader from model import ColorTransformerModel @@ -51,6 +52,11 @@ if __name__ == "__main__": verbose=True, ) + save_img_callback = SaveImageCallback( + save_interval=0, + final_dir="out", + ) + # Initialize data loader with parsed arguments # named_data_loader also has grayscale extras. TODO: remove unnamed train_dataloader = create_named_dataloader( @@ -71,7 +77,7 @@ if __name__ == "__main__": # Initialize trainer with parsed arguments trainer = pl.Trainer( - callbacks=[early_stop_callback], + callbacks=[early_stop_callback, save_img_callback], max_epochs=args.max_epochs, log_every_n_steps=args.log_every_n_steps, ) diff --git a/makefile b/makefile index 7da0b5d..cff89bb 100644 --- a/makefile +++ b/makefile @@ -8,3 +8,9 @@ test: search: python search.py + +animate: + ffmpeg -i lightning_logs/version_258/e%04d.png \ + -c:v libx264 \ + -vf "fps=12,format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" \ + ~/animated.mp4 \ No newline at end of file diff --git a/out/index.html b/out/index.html index 3875c4a..f45c93b 100644 --- a/out/index.html +++ b/out/index.html @@ -72,7 +72,7 @@ function loadImages() { var gallery = document.getElementById('gallery'); - for (var i = 175; i < 275; i++) { // Changed from i <= 100 to i < 100 + for (var i = 1; i <= 100; i++) { // Changed from i <= 100 to i < 100 let imageName = 'v' + i + '.png'; let img = document.createElement('img'); img.src = imageName;