|
|
@ -6,18 +6,22 @@ from check import create_circle |
|
|
|
|
|
|
|
|
|
|
|
class SaveImageCallback(pl.Callback): |
|
|
|
def __init__(self, save_interval=1): |
|
|
|
def __init__(self, save_interval=1, final_dir: str = None): |
|
|
|
self.save_interval = save_interval |
|
|
|
self.final_dir = final_dir |
|
|
|
|
|
|
|
def on_train_epoch_end(self, trainer, pl_module, outputs): |
|
|
|
def on_train_epoch_end(self, trainer, pl_module): |
|
|
|
epoch = trainer.current_epoch |
|
|
|
if self.save_interval <= 0: |
|
|
|
return None |
|
|
|
|
|
|
|
if epoch % self.save_interval == 0: |
|
|
|
# Set the model to eval mode for generating the image |
|
|
|
pl_module.eval() |
|
|
|
|
|
|
|
# Save the image |
|
|
|
# if pl_module.trainer.logger: |
|
|
|
# version = pl_module.trainer.logger.version |
|
|
|
# |
|
|
|
# else: |
|
|
|
# version = 0 |
|
|
|
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}") |
|
|
@ -25,3 +29,10 @@ class SaveImageCallback(pl.Callback): |
|
|
|
|
|
|
|
# Make sure to set it back to train mode |
|
|
|
pl_module.train() |
|
|
|
|
|
|
|
def on_train_end(self, trainer, pl_module): |
|
|
|
if self.final_dir: |
|
|
|
version = pl_module.trainer.logger.version |
|
|
|
fname = Path(f"{self.final_dir}") / Path(f"v{version}") |
|
|
|
pl_module.eval() |
|
|
|
create_circle(pl_module, fname=fname) |
|
|
|