diff --git a/callbacks.py b/callbacks.py new file mode 100644 index 0000000..4f25b56 --- /dev/null +++ b/callbacks.py @@ -0,0 +1,27 @@ +from pathlib import Path + +import pytorch_lightning as pl + +from check import create_circle + + +class SaveImageCallback(pl.Callback): + def __init__(self, save_interval=1): + self.save_interval = save_interval + + def on_train_epoch_end(self, trainer, pl_module, outputs): + epoch = trainer.current_epoch + if epoch % self.save_interval == 0: + # Set the model to eval mode for generating the image + pl_module.eval() + + # Save the image + # if pl_module.trainer.logger: + # version = pl_module.trainer.logger.version + # else: + # version = 0 + fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}") + create_circle(pl_module, fname=fname) + + # Make sure to set it back to train mode + pl_module.train() diff --git a/check.py b/check.py index 29d2e8d..9e3a44b 100644 --- a/check.py +++ b/check.py @@ -39,8 +39,10 @@ def make_image(ckpt: str, fname: str, color=True): def create_circle(ckpt: str, fname: str): - M = ColorTransformerModel.load_from_checkpoint(ckpt) - M.eval() + if isinstance(ckpt, str): + M = ColorTransformerModel.load_from_checkpoint(ckpt) + else: + M = ckpt rgb_tensor, names = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor)