You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.3 KiB

import subprocess
10 months ago
from pathlib import Path
from lightning import Callback
10 months ago
from check import create_circle
class SaveImageCallback(Callback):
    """Lightning callback that renders ``create_circle`` images during training.

    Every ``save_interval`` epochs an image named ``e{epoch:04d}`` is rendered
    into the logger's ``log_dir``. When training ends:

    * an optional final image ``v{version}`` is rendered into ``final_dir``;
    * every handle returned by ``create_circle`` is joined;
    * if per-epoch images were enabled, ffmpeg is launched in the background
      to stitch them into ``a{version}.mp4`` inside ``log_dir``.

    Args:
        save_interval: Render every this many epochs. ``<= 0`` disables both
            the per-epoch images and the video.
        final_dir: Directory for the end-of-training image; ``None``/empty
            skips it.
        radius: Forwarded to ``create_circle``.
        dpi: Forwarded to ``create_circle``.
        figsize: Forwarded to ``create_circle``, in inches (pixels are
            ``figsize * dpi``). ``(15, 15)`` at 300 dpi is 4500 x 4500 px;
            14.4 in at 300 dpi is 4320 px, the height of an 8K UHD frame.
        fps: Frame rate of the encoded video.
        video_size: ``(width, height)`` of the encoded video in pixels. Frames
            are scaled to fit and padded with white.
    """

    def __init__(
        self,
        save_interval=1,
        final_dir: str = None,
        radius: float = 0.5,
        dpi: int = 300,
        figsize=(15, 15),
        fps: int = 12,
        video_size=(7680, 4320),
    ):
        super().__init__()
        self.save_interval = save_interval
        self.final_dir = final_dir
        self.radius = radius
        self.dpi = dpi
        self.figsize = figsize
        self.fps = fps
        self.video_size = video_size
        # Handles returned by create_circle (they expose .join()); joined in
        # on_train_end before the video is encoded.
        self.processes = []
        # Background ffmpeg process, if one was launched.
        self.video_process = None

    def on_train_epoch_end(self, trainer, pl_module):
        """Render ``e{epoch:04d}`` into the log dir every ``save_interval`` epochs."""
        if self.save_interval <= 0:
            return
        epoch = trainer.current_epoch
        if epoch % self.save_interval != 0:
            return
        # NOTE(review): there is no rank guard, so under multi-process training
        # every rank would render the same file — confirm whether that matters.
        fname = Path(trainer.logger.log_dir) / f"e{epoch:04d}"
        self._render(pl_module, fname)

    def on_train_end(self, trainer, pl_module):
        """Render the final image, wait for all renders, then encode the video."""
        if self.final_dir:
            fname = Path(self.final_dir) / f"v{trainer.logger.version}"
            self._render(pl_module, fname)

        # Always wait for outstanding renders — ffmpeg must not start before
        # the last frame is on disk.
        self._join_renders()

        if self.save_interval > 0:
            # Read version here rather than relying on the final_dir branch,
            # which may not have run.
            self._encode_video(
                Path(trainer.logger.log_dir), trainer.logger.version
            )

    def _render(self, pl_module, fname):
        """Render ``pl_module`` to ``fname`` in eval mode and keep the handle."""
        was_training = pl_module.training
        pl_module.eval()
        try:
            proc = create_circle(
                pl_module,
                fname=fname,
                dpi=self.dpi,
                figsize=self.figsize,
                radius=self.radius,
            )
        finally:
            # Restore the previous mode even if rendering raised.
            pl_module.train(was_training)
        # Presumably create_circle always returns a started process; guard
        # anyway so a synchronous implementation returning None still works.
        if proc is not None:
            self.processes.append(proc)

    def _join_renders(self):
        """Block until every pending render handle finishes, then forget them."""
        for proc in self.processes:
            proc.join()
        self.processes.clear()

    def _encode_video(self, log_dir, version):
        """Launch ffmpeg in the background to stitch the ``e*.png`` frames."""
        import warnings

        w, h = self.video_size
        video_filter = (
            f"scale={w}:{h}:force_original_aspect_ratio=decrease,"
            f"pad={w}:{h}:(ow-iw)/2:(oh-ih)/2:color=white"
        )
        # assumes create_circle writes "<fname>.png" — TODO confirm.
        if self.save_interval == 1:
            input_args = ["-i", str(log_dir / "e%04d.png")]
        else:
            # The image2 %04d pattern needs consecutive numbers and stops at
            # the first gap (e0000, e0002, ...); let ffmpeg glob instead.
            input_args = [
                "-pattern_type", "glob",
                "-i", str(log_dir / "e[0-9]*.png"),
            ]
        cmd = [
            "ffmpeg",
            "-r", str(self.fps),
            "-f", "image2",
            *input_args,
            "-vcodec", "libx264",
            "-crf", "25",
            "-pix_fmt", "yuv420p",
            "-vf", video_filter,
            str(log_dir / f"a{version}.mp4"),
        ]
        try:
            # Argument list without a shell: paths containing spaces or shell
            # metacharacters are passed through verbatim.
            self.video_process = subprocess.Popen(cmd)
        except FileNotFoundError:
            # Encoding is a best-effort extra; don't fail a finished run.
            warnings.warn("ffmpeg not found on PATH; skipping video encoding")