You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
856 B
27 lines
856 B
from pathlib import Path
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
from check import create_circle
|
|
|
|
|
|
class SaveImageCallback(pl.Callback):
    """Render and save an image from the module every ``save_interval`` epochs.

    At the end of each training epoch whose index is a multiple of
    ``save_interval`` (epoch 0 included), the module is put in eval mode and
    ``create_circle`` is called with a file name of the form
    ``<log_dir>/eNNNN`` (zero-padded epoch number). Afterwards, the module's
    previous train/eval mode is restored.

    Args:
        save_interval: Save an image every this many epochs. Must be >= 1.

    Raises:
        ValueError: If ``save_interval`` is less than 1.
    """

    def __init__(self, save_interval=1):
        super().__init__()
        if save_interval < 1:
            # ``epoch % 0`` would raise ZeroDivisionError at the first epoch
            # end; fail fast at construction time instead.
            raise ValueError(f"save_interval must be >= 1, got {save_interval!r}")
        self.save_interval = save_interval

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Save an image for the just-finished epoch if it falls on the interval.

        Args:
            trainer: The Lightning trainer running the fit.
            pl_module: The module being trained; passed to ``create_circle``.
            outputs: Unused. Older Lightning versions passed epoch outputs to
                this hook, but newer ones call it with only ``trainer`` and
                ``pl_module``. Defaulting to ``None`` keeps both call styles
                working.
        """
        epoch = trainer.current_epoch
        if epoch % self.save_interval != 0:
            return

        fname = self._output_dir(trainer) / f"e{epoch:04d}"

        # NOTE(review): under multi-process training this hook presumably runs
        # on every rank; if create_circle only writes a file, consider guarding
        # with ``trainer.is_global_zero`` — confirm against create_circle.
        was_training = pl_module.training
        pl_module.eval()
        try:
            create_circle(pl_module, fname=fname)
        finally:
            # Restore the previous mode even if create_circle raises, so
            # training never silently continues in eval mode.
            pl_module.train(was_training)

    @staticmethod
    def _output_dir(trainer):
        """Return the directory that images are written to.

        Uses the logger's ``log_dir`` when a configured logger exposes one.
        Otherwise, for example when training with no logger, falls back to
        ``trainer.default_root_dir``.
        """
        log_dir = getattr(trainer.logger, "log_dir", None)
        if log_dir is None:
            log_dir = trainer.default_root_dir
        return Path(log_dir)
|
|
|