|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from lightning import Callback
|
|
|
|
|
|
|
|
from check import create_circle
|
|
|
|
|
|
|
|
|
|
|
|
class SaveImageCallback(Callback):
    """Lightning callback that renders ``create_circle`` frames during training.

    Every ``save_interval`` epochs (epoch 0 included) a frame named
    ``e{epoch:04d}`` is written into the run's log directory. When training
    ends, an optional high-resolution plot ``v{version}`` is written to
    ``final_dir`` and, if per-epoch frames were enabled, the frames are
    stitched into ``a{version}.mp4`` with ffmpeg (best effort).

    NOTE(review): assumes ``create_circle`` writes ``<fname>.png``; the
    original ffmpeg ``e%04d.png`` input pattern relied on the same thing.

    Args:
        save_interval: Save a frame every this many epochs. A value ``<= 0``
            disables per-epoch frames and the video.
        final_dir: Directory for the final plot; ``None``/empty skips it.
        fps: Frame rate of the assembled video.
    """

    # ffmpeg filter: fit each frame inside 1920x1080 and pad with white.
    _VIDEO_FILTER = (
        "scale=1920:1080:force_original_aspect_ratio=decrease,"
        "pad=1920:1080:(ow-iw)/2:(oh-ih)/2:color=white"
    )

    def __init__(
        self,
        save_interval=1,
        final_dir: "str | None" = None,
        fps: int = 12,
    ):
        super().__init__()
        self.save_interval = save_interval
        self.final_dir = final_dir
        self.fps = fps

    def on_train_epoch_end(self, trainer, pl_module):
        """Write frame ``e{epoch:04d}`` to the log directory every ``save_interval`` epochs."""
        if self.save_interval <= 0:
            return None
        epoch = trainer.current_epoch
        if epoch % self.save_interval != 0:
            return None

        log_dir = self._log_dir(trainer)
        # A logger may report its log_dir before anything has created it.
        log_dir.mkdir(parents=True, exist_ok=True)
        self._render(pl_module, log_dir / f"e{epoch:04d}")
        return None

    def on_train_end(self, trainer, pl_module):
        """Write the final plot (if ``final_dir``) and assemble the frame video."""
        # Computed unconditionally: the video name needs it even without final_dir.
        version = self._logger_version(trainer)

        if self.final_dir:
            final_dir = Path(self.final_dir)
            final_dir.mkdir(parents=True, exist_ok=True)
            self._render(
                pl_module, final_dir / f"v{version}", dpi=300, figsize=(6, 6)
            )

        if self.save_interval > 0:
            self._make_video(self._log_dir(trainer), version)

    @staticmethod
    def _log_dir(trainer) -> Path:
        """Return the logger's ``log_dir``, or ``trainer.default_root_dir`` if unavailable."""
        # getattr on a None logger also yields None, covering logger=False.
        log_dir = getattr(trainer.logger, "log_dir", None)
        return Path(log_dir if log_dir is not None else trainer.default_root_dir)

    @staticmethod
    def _logger_version(trainer):
        """Return the logger's version, or 0 when there is no logger/version."""
        version = getattr(trainer.logger, "version", None)
        return 0 if version is None else version

    @staticmethod
    def _render(pl_module, fname, **plot_kwargs) -> None:
        """Call ``create_circle`` with ``pl_module`` in eval mode, then restore its prior mode."""
        was_training = pl_module.training
        pl_module.eval()
        try:
            create_circle(pl_module, fname=fname, **plot_kwargs)
        finally:
            # Restore even if plotting raises, so training never resumes in eval mode.
            pl_module.train(was_training)

    def _make_video(self, frame_dir: Path, version) -> None:
        """Stitch ``e<epoch>.png`` frames in ``frame_dir`` into ``a{version}.mp4``.

        Best effort: missing ffmpeg or an ffmpeg failure produces a warning,
        never an exception.
        """
        import shutil
        import subprocess
        import warnings

        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            warnings.warn("ffmpeg not found on PATH; skipping training video.")
            return

        # Sort numerically: epochs above 9999 produce wider names (e10000),
        # which a lexical sort would misplace.
        frames = sorted(
            (p for p in frame_dir.glob("e*.png") if p.stem[1:].isdigit()),
            key=lambda p: int(p.stem[1:]),
        )
        if not frames:
            return

        # Frames are piped rather than read through an e%04d.png pattern: the
        # image2 sequence pattern stops at the first missing number, so any
        # save_interval > 1 would otherwise produce a one-frame video.
        # An argument list (no shell) keeps paths with spaces intact.
        cmd = [
            ffmpeg,
            "-y",  # stdin carries frame data, so ffmpeg cannot prompt to overwrite
            "-f", "image2pipe",
            "-framerate", str(self.fps),
            "-c:v", "png",
            "-i", "-",
            "-vcodec", "libx264",
            "-crf", "25",
            "-pix_fmt", "yuv420p",
            "-vf", self._VIDEO_FILTER,
            str(frame_dir / f"a{version}.mp4"),
        ]
        frame_bytes = b"".join(p.read_bytes() for p in frames)
        result = subprocess.run(cmd, input=frame_bytes, check=False)
        if result.returncode != 0:
            warnings.warn(
                f"ffmpeg exited with status {result.returncode}; "
                f"training video a{version}.mp4 may be missing."
            )
|