From 2582bdbbecb66633484faa23dd78b2afff04a174 Mon Sep 17 00:00:00 2001 From: "Michael Pilosov, PhD" Date: Sun, 3 Mar 2024 20:11:25 +0000 Subject: [PATCH] much faster animations --- baseline.py | 2 +- callbacks.py | 52 ++++++++++++++++++++++++++++++++------------ check.py | 45 +++++++++++++++++++++++++++++++++----- newsearch.py | 6 ++--- scripts/sortcolor.py | 9 ++++---- 5 files changed, 85 insertions(+), 29 deletions(-) diff --git a/baseline.py b/baseline.py index 2a5f9ab..adca447 100644 --- a/baseline.py +++ b/baseline.py @@ -120,7 +120,7 @@ plot_preds( fname, roll=False, dpi=DPI, - inner_radius=INNER_RADIUS, + radius=INNER_RADIUS, figsize=(SIZE, SIZE), fsize=FONTSIZE, label=f"{KIND.upper()}", diff --git a/callbacks.py b/callbacks.py index 39a4309..4ceb060 100644 --- a/callbacks.py +++ b/callbacks.py @@ -1,3 +1,4 @@ +import subprocess from pathlib import Path from lightning import Callback @@ -6,9 +7,20 @@ from check import create_circle class SaveImageCallback(Callback): - def __init__(self, save_interval=1, final_dir: str = None): + def __init__( + self, + save_interval=1, + final_dir: str = None, + radius: float = 0.5, + dpi: int = 300, + figsize=(15, 15), # 4k is at 14.4 + ): self.save_interval = save_interval self.final_dir = final_dir + self.radius = radius + self.dpi = dpi + self.figsize = figsize + self.processes = [] def on_train_epoch_end(self, trainer, pl_module): epoch = trainer.current_epoch @@ -20,12 +32,15 @@ class SaveImageCallback(Callback): pl_module.eval() # Save the image - # if pl_module.trainer.logger: - # - # else: - # version = 0 fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}") - create_circle(pl_module, fname=fname, dpi=300, figsize=(6, 6)) + p = create_circle( + pl_module, + fname=fname, + dpi=self.dpi, + figsize=self.figsize, + radius=self.radius, + ) + self.processes.append(p) # Make sure to set it back to train mode pl_module.train() @@ -35,14 +50,23 @@ class SaveImageCallback(Callback): version = pl_module.trainer.logger.version fname = Path(f"{self.final_dir}") / Path(f"v{version}") pl_module.eval() - create_circle(pl_module, fname=fname, dpi=300, figsize=(6, 6)) - if self.save_interval > 0: - import os + p = create_circle( + pl_module, + fname=fname, + dpi=self.dpi, + figsize=self.figsize, + radius=self.radius, + ) + self.processes.append(p) + + # Wait for all subprocesses to finish + for p in self.processes: + p.join() + if self.save_interval > 0: log_dir = str(Path(pl_module.trainer.logger.log_dir)) fps = 12 - _cmd = f'ffmpeg -r {fps} -f image2 -i {log_dir}/e%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "scale=1920:1080:force_original_aspect_ratio=decrease,pad=1920:1080:(ow-iw)/2:(oh-ih)/2:color=white" {log_dir}/a{version}.mp4' - os.system(_cmd) - # os.system( - # f'ffmpeg -i {log_dir}/e%04d.png -c:v libx264 -vf "fps={fps},format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" {log_dir}/a{version}.mp4' - # ) + # w, h = self.figsize[0] * self.dpi, self.figsize[1] * self.dpi + w, h = 7680, 4320 + _cmd = f'ffmpeg -r {fps} -f image2 -i {log_dir}/e%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2:color=white" {log_dir}/a{version}.mp4' + _ = subprocess.Popen(_cmd, shell=True) diff --git a/check.py b/check.py index c15f542..ba10d53 100644 --- a/check.py +++ b/check.py @@ -1,4 +1,6 @@ # import matplotlib.patches as patches +import pickle +from multiprocessing import Process from pathlib import Path from typing import Tuple, Union @@ -42,11 +44,34 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs): plt.savefig(f"{fname}.png", **kwargs) +# def create_circle( +# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs +# ): +# if isinstance(ckpt, str): + +# M = ColorTransformerModel.load_from_checkpoint( +# ckpt, map_location=lambda storage, loc: storage +# ) +# else: +# M = ckpt + +# xkcd_colors, _ = extract_colors() +# xkcd_colors = preprocess_data(xkcd_colors).to(M.device) +# preds = M(xkcd_colors).detach().cpu().numpy() +# rgb_array = xkcd_colors.detach().cpu().numpy() +# plot_preds(preds, rgb_array, fname=fname, **kwargs) + + +def plot_preds_serialized(serialized_data, fname, **kwargs): + # Deserialize the data + preds, rgb_array = pickle.loads(serialized_data) + plot_preds(preds, rgb_array, fname=fname, **kwargs) + + def create_circle( ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs ): if isinstance(ckpt, str): - M = ColorTransformerModel.load_from_checkpoint( ckpt, map_location=lambda storage, loc: storage ) @@ -57,7 +82,16 @@ def create_circle( xkcd_colors = preprocess_data(xkcd_colors).to(M.device) preds = M(xkcd_colors).detach().cpu().numpy() rgb_array = xkcd_colors.detach().cpu().numpy() - plot_preds(preds, rgb_array, fname=fname, **kwargs) + + # Serialize the data + serialized_data = pickle.dumps((preds, rgb_array)) + + # Run plot_preds_serialized function in a separate process + p = Process( + target=plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs + ) + p.start() + return p def plot_preds( @@ -65,7 +99,7 @@ def plot_preds( rgb_values, fname: str, roll: bool = False, - inner_radius: float = 1 / 3, + radius: float = 1 / 2, dpi: int = 300, figsize: Tuple[float] = (6, 6), fsize: int = 0, @@ -113,12 +147,11 @@ def plot_preds( ax.set_yticks([]) ax.set_aspect("equal") ax.axis("off") - radius = 1 - ax.set_ylim(0, radius) + ax.set_ylim(0, 1) # implicit outer radius of 1 # Overlay white circle circle = patches.Circle( - (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 + (0, 0), radius, transform=ax.transData._b, color="white", zorder=2 ) ax.add_patch(circle) diff --git a/newsearch.py b/newsearch.py index 757ec5a..071cf9e 100644 --- a/newsearch.py +++ b/newsearch.py @@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401 # consistency of randomly sampled experiments. seed(19920921) -NUM_JOBS = 33 +NUM_JOBS = 1 # reference to the current studio # if you run outside of Lightning, you can pass the Studio name @@ -33,12 +33,12 @@ depths = [1, 2, 4, 8, 16] # widths, depths = [512], [4] batch_size_values = [256] -max_epochs_values = [420] # at 12 fps, around 35s +max_epochs_values = [42] # at 12 fps, around 35s seeds = list(range(21, 1992)) optimizers = [ # "Adagrad", "Adam", - "SGD", + # "SGD", # "AdamW", # "LBFGS", # "RAdam", diff --git a/scripts/sortcolor.py b/scripts/sortcolor.py index 9ebf675..37c4a3b 100644 --- a/scripts/sortcolor.py +++ b/scripts/sortcolor.py @@ -220,7 +220,7 @@ def plot_preds( rgb_values, fname: str, roll: bool = False, - inner_radius: float = 1 / 3, + radius: float = 1 / 3, dpi: int = 300, figsize=(6, 6), ): @@ -266,12 +266,11 @@ def plot_preds( ax.set_yticks([]) ax.set_aspect("equal") ax.axis("off") - radius = 1 - ax.set_ylim(0, radius) + ax.set_ylim(0, 1) # Overlay white circle circle = patches.Circle( - (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 + (0, 0), radius, transform=ax.transData._b, color="white", zorder=2 ) ax.add_patch(circle) @@ -299,7 +298,7 @@ plot_preds( fname, roll=False, dpi=DPI, - inner_radius=INNER_RADIUS, + radius=INNER_RADIUS, figsize=(SIZE, SIZE), ) print(f"saved {fname}")