You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

74 lines
2.4 KiB

# import subprocess
import os
from pathlib import Path
from lightning import Callback
from check import create_circle
class SaveImageCallback(Callback):
    """Lightning callback that renders ``create_circle`` snapshots during training.

    Every ``save_interval`` epochs an image named ``e{epoch:04d}`` is written
    into the logger's ``log_dir``. When training ends, an optional final image
    ``v{version}`` is written to ``final_dir``, all pending render jobs are
    joined, and (if periodic snapshots are enabled) the snapshots are stitched
    into ``{log_dir}/a{version}.mp4`` with ``ffmpeg``.

    Args:
        save_interval: Save a snapshot every this many epochs; ``<= 0``
            disables periodic snapshots and the video.
        final_dir: Directory for the final image; ``None``/empty skips it.
        radius: Forwarded to ``create_circle``.
        dpi: Forwarded to ``create_circle``.
        figsize: Forwarded to ``create_circle``.
        fps: Frame rate of the assembled video.
        video_size: ``(width, height)`` in pixels of the assembled video;
            frames are scaled to fit and padded with white.
    """

    def __init__(
        self,
        save_interval=1,
        final_dir: "str | None" = None,
        radius: float = 0.5,
        dpi: int = 300,
        # 14.4 in * 300 dpi = 4320 px (the 8K UHD height). Presumably the
        # rendered size is figsize * dpi -- confirm in create_circle.
        figsize=(15, 15),
        fps: int = 12,
        video_size=(7680, 4320),
    ):
        super().__init__()
        self.save_interval = save_interval
        self.final_dir = final_dir
        self.radius = radius
        self.dpi = dpi
        self.figsize = figsize
        self.fps = fps
        self.video_size = video_size
        # Handles returned by create_circle; joined (then dropped) in on_train_end.
        self.processes = []

    def _render(self, pl_module, fname):
        """Render ``pl_module`` to ``fname`` in eval mode, then restore its mode."""
        was_training = pl_module.training
        pl_module.eval()
        try:
            p = create_circle(
                pl_module,
                fname=fname,
                dpi=self.dpi,
                figsize=self.figsize,
                radius=self.radius,
            )
        finally:
            # Restore the previous mode even if rendering raises.
            pl_module.train(was_training)
        # create_circle appears to return a joinable handle (e.g. a
        # multiprocessing.Process) -- TODO confirm; skip None defensively.
        if p is not None:
            self.processes.append(p)

    def on_train_epoch_end(self, trainer, pl_module):
        """Save a snapshot every ``save_interval`` epochs (disabled if <= 0)."""
        if self.save_interval <= 0:
            return None
        epoch = trainer.current_epoch
        if epoch % self.save_interval == 0:
            self._render(pl_module, Path(trainer.logger.log_dir) / f"e{epoch:04d}")

    def on_train_end(self, trainer, pl_module):
        """Render the final image, wait for pending renders, and build the video."""
        # Bound unconditionally: both the final image and the video name use it.
        version = trainer.logger.version
        if self.final_dir:
            self._render(pl_module, Path(self.final_dir) / f"v{version}")

        # Every frame must exist on disk before ffmpeg runs.
        for p in self.processes:
            p.join()
        self.processes.clear()

        if self.save_interval > 0:
            self._make_video(Path(trainer.logger.log_dir), version)

    def _ffmpeg_cmd(self, log_dir, version):
        """Build the ffmpeg argv that stitches the ``e*.png`` snapshots into an mp4."""
        w, h = self.video_size
        cmd = ["ffmpeg", "-r", str(self.fps), "-f", "image2"]
        if self.save_interval == 1:
            cmd += ["-i", str(log_dir / "e%04d.png")]
        else:
            # The %04d sequence mode stops at the first missing number, so with
            # gaps (e0000, e0002, ...) it would read a single frame. Glob
            # matches sort lexicographically, which fits the zero-padded names.
            # NOTE: glob needs an ffmpeg built with globbing support.
            cmd += ["-pattern_type", "glob", "-i", str(log_dir / "e*.png")]
        vf = (
            f"scale={w}:{h}:force_original_aspect_ratio=decrease,"
            f"pad={w}:{h}:(ow-iw)/2:(oh-ih)/2:color=white"
        )
        cmd += [
            "-vcodec", "libx264",
            "-crf", "25",
            "-pix_fmt", "yuv420p",
            "-vf", vf,
            str(log_dir / f"a{version}.mp4"),
        ]
        return cmd

    def _make_video(self, log_dir, version):
        """Run ffmpeg; warn rather than crash training teardown on failure."""
        import subprocess
        import warnings

        cmd = self._ffmpeg_cmd(log_dir, version)
        try:
            # argv list, no shell: paths with spaces/metacharacters are safe.
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            warnings.warn("ffmpeg not found on PATH; skipping video export.")
            return
        if result.returncode != 0:
            warnings.warn(
                f"ffmpeg exited with code {result.returncode}; video may be missing."
            )