colors/callbacks.py

47 lines
1.5 KiB
Python
Raw Normal View History

2024-01-14 07:28:23 +00:00
import shutil
import subprocess
import warnings
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl

from check import create_circle
class SaveImageCallback(pl.Callback):
    """Render the model with ``create_circle`` during and after training.

    Every ``save_interval`` epochs a frame ``e{epoch:04d}`` is written into the
    logger's ``log_dir``. When training ends, a final image ``v{version}`` is
    written into ``final_dir`` (if given), and any per-epoch frames are
    stitched into ``{log_dir}/a{version}.mp4`` with ffmpeg.

    NOTE(review): the video step assumes ``create_circle(..., fname=p)`` writes
    ``p`` with a ``.png`` suffix (it looks for ``e*.png``) — confirm in ``check``.

    Args:
        save_interval: Save a frame every this many epochs. ``<= 0`` disables
            both the per-epoch frames and the video.
        final_dir: Directory for the final image. ``None``/empty disables it.
        fps: Playback rate of the video; every saved frame is shown once.
    """

    def __init__(
        self,
        save_interval: int = 1,
        final_dir: Optional[str] = None,
        fps: int = 12,
    ):
        super().__init__()
        self.save_interval = save_interval
        self.final_dir = final_dir
        self.fps = fps

    @staticmethod
    def _log_dir(trainer) -> Optional[Path]:
        """Return the logger's ``log_dir``, or ``None`` when no logger is attached."""
        logger = trainer.logger
        if logger is None:
            return None
        return Path(logger.log_dir)

    @staticmethod
    def _render(pl_module, fname: Path) -> None:
        """Run ``create_circle`` in eval mode, then restore the module's previous mode."""
        was_training = pl_module.training
        pl_module.eval()
        try:
            create_circle(pl_module, fname=fname)
        finally:
            # Restore even if rendering fails, so training is not left in eval mode.
            pl_module.train(was_training)

    def _make_video(self, log_dir: Path, version) -> None:
        """Stitch the ``e*.png`` frames in *log_dir* into ``a{version}.mp4``.

        Failures (ffmpeg missing or exiting non-zero) emit a warning instead of
        raising, so the end of training is never aborted by video encoding.
        """
        if not any(log_dir.glob("e*.png")):
            return
        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            warnings.warn("ffmpeg not found on PATH; skipping video creation.")
            return
        if self.save_interval == 1:
            # Frames e0000, e0001, ... are consecutive: portable sequence pattern.
            input_args = ["-i", "e%04d.png"]
        else:
            # ffmpeg's %04d sequence pattern stops at the first missing number,
            # so with gaps (e0000, e0002, ...) use a glob instead. Glob input is
            # not available in Windows ffmpeg builds.
            input_args = ["-pattern_type", "glob", "-i", "e*.png"]
        cmd = [
            ffmpeg,
            "-y",
            # Set the *input* rate; the image2 default of 25 fps combined with
            # an fps=12 filter would drop roughly half of the frames.
            "-framerate", str(self.fps),
            *input_args,
            "-c:v", "libx264",
            # Pad to even dimensions first, as yuv420p/libx264 require them.
            "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2,format=yuv420p",
            f"a{version}.mp4",
        ]
        # Argument list (no shell) + cwd, so spaces or '%' in log_dir are harmless.
        result = subprocess.run(
            cmd,
            cwd=log_dir,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            errors="replace",
            check=False,
        )
        if result.returncode != 0:
            warnings.warn(
                f"ffmpeg exited with code {result.returncode}: "
                f"{result.stderr.strip()[-500:]}"
            )

    def on_train_epoch_end(self, trainer, pl_module):
        """Save frame ``e{epoch:04d}`` into the log dir every ``save_interval`` epochs."""
        if self.save_interval <= 0:
            return
        epoch = trainer.current_epoch
        if epoch % self.save_interval != 0:
            return
        log_dir = self._log_dir(trainer)
        if log_dir is None:
            # No logger -> no log_dir to write frames into.
            return
        self._render(pl_module, log_dir / f"e{epoch:04d}")

    def on_train_end(self, trainer, pl_module):
        """Save the final image to ``final_dir`` and assemble the per-epoch video."""
        logger = trainer.logger
        # Without a logger, fall back to version 0.
        version = logger.version if logger is not None else 0
        if self.final_dir:
            final_dir = Path(self.final_dir)
            final_dir.mkdir(parents=True, exist_ok=True)
            self._render(pl_module, final_dir / f"v{version}")
        # Independent of final_dir: the video only needs the per-epoch frames.
        if self.save_interval > 0:
            log_dir = self._log_dir(trainer)
            if log_dir is not None:
                self._make_video(log_dir, version)