Michael Pilosov
10 months ago
2 changed files with 31 additions and 2 deletions
@ -0,0 +1,27 @@ |
|||
from pathlib import Path |
|||
|
|||
import pytorch_lightning as pl |
|||
|
|||
from check import create_circle |
|||
|
|||
|
|||
class SaveImageCallback(pl.Callback): |
|||
def __init__(self, save_interval=1): |
|||
self.save_interval = save_interval |
|||
|
|||
def on_train_epoch_end(self, trainer, pl_module, outputs): |
|||
epoch = trainer.current_epoch |
|||
if epoch % self.save_interval == 0: |
|||
# Set the model to eval mode for generating the image |
|||
pl_module.eval() |
|||
|
|||
# Save the image |
|||
# if pl_module.trainer.logger: |
|||
# version = pl_module.trainer.logger.version |
|||
# else: |
|||
# version = 0 |
|||
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}") |
|||
create_circle(pl_module, fname=fname) |
|||
|
|||
# Make sure to set it back to train mode |
|||
pl_module.train() |
Loading…
Reference in new issue