Browse Source

add callback

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
e36126961b
  1. 27
      callbacks.py
  2. 6
      check.py

27
callbacks.py

@ -0,0 +1,27 @@
from pathlib import Path
import pytorch_lightning as pl
from check import create_circle
class SaveImageCallback(pl.Callback):
def __init__(self, save_interval=1):
self.save_interval = save_interval
def on_train_epoch_end(self, trainer, pl_module, outputs):
epoch = trainer.current_epoch
if epoch % self.save_interval == 0:
# Set the model to eval mode for generating the image
pl_module.eval()
# Save the image
# if pl_module.trainer.logger:
# version = pl_module.trainer.logger.version
# else:
# version = 0
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}")
create_circle(pl_module, fname=fname)
# Make sure to set it back to train mode
pl_module.train()

6
check.py

@ -39,8 +39,10 @@ def make_image(ckpt: str, fname: str, color=True):
def create_circle(ckpt: str, fname: str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
M.eval()
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
else:
M = ckpt
rgb_tensor, names = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor)

Loading…
Cancel
Save