from pathlib import Path

import pytorch_lightning as pl

from check import create_circle


class SaveImageCallback(pl.Callback):
    """Lightning callback that periodically calls ``create_circle`` on the model.

    At the end of every training epoch whose index is a multiple of
    ``save_interval`` the module is put in eval mode and passed to
    ``create_circle`` together with a target path ``<log_dir>/eNNNN``
    (``NNNN`` is the zero-padded epoch index).

    NOTE(review): ``create_circle`` comes from the project's ``check`` module;
    it is presumably what renders and writes the image, but its file format
    and whether it creates missing directories are not visible here.
    """

    def __init__(self, save_interval=1):
        """Store how often (in epochs) the image is generated.

        Args:
            save_interval: Generate an image when ``epoch % save_interval == 0``.
                Must not be zero.

        Raises:
            ValueError: If ``save_interval`` is zero (the modulo in
                ``on_train_epoch_end`` would otherwise fail mid-training).
        """
        super().__init__()
        if save_interval == 0:
            raise ValueError("save_interval must be non-zero")
        self.save_interval = save_interval

    @staticmethod
    def _output_dir(trainer):
        """Return the directory to write into, even when no logger is set."""
        logger = trainer.logger
        log_dir = getattr(logger, "log_dir", None) if logger is not None else None
        if log_dir is None:
            # No logger (e.g. ``logger=False``) or a logger without a log_dir:
            # fall back to the trainer's root directory instead of crashing.
            log_dir = trainer.default_root_dir
        return Path(log_dir)

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Generate the image for the current epoch if it is due.

        Args:
            trainer: The running trainer; provides the epoch index and log dir.
            pl_module: The module handed to ``create_circle``.
            outputs: Unused. Kept optional because older Lightning versions
                pass it while newer ones call the hook without it.
        """
        epoch = trainer.current_epoch
        if epoch % self.save_interval != 0:
            return

        fname = self._output_dir(trainer) / f"e{epoch:04d}"

        # Remember the current mode so it is restored exactly, rather than
        # forcing train mode regardless of what it was before.
        was_training = pl_module.training
        pl_module.eval()
        try:
            create_circle(pl_module, fname=fname)
        finally:
            # Restore the mode even if image generation fails.
            pl_module.train(was_training)