Browse Source

much faster animations

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
2582bdbbec
  1. 2
      baseline.py
  2. 52
      callbacks.py
  3. 45
      check.py
  4. 6
      newsearch.py
  5. 9
      scripts/sortcolor.py

2
baseline.py

@ -120,7 +120,7 @@ plot_preds(
fname,
roll=False,
dpi=DPI,
inner_radius=INNER_RADIUS,
radius=INNER_RADIUS,
figsize=(SIZE, SIZE),
fsize=FONTSIZE,
label=f"{KIND.upper()}",

52
callbacks.py

@ -1,3 +1,4 @@
import subprocess
from pathlib import Path
from lightning import Callback
@ -6,9 +7,20 @@ from check import create_circle
class SaveImageCallback(Callback):
def __init__(self, save_interval=1, final_dir: str = None):
def __init__(
self,
save_interval=1,
final_dir: str = None,
radius: float = 0.5,
dpi: int = 300,
figsize=(15, 15), # 4k is at 14.4
):
self.save_interval = save_interval
self.final_dir = final_dir
self.radius = radius
self.dpi = dpi
self.figsize = figsize
self.processes = []
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
@ -20,12 +32,15 @@ class SaveImageCallback(Callback):
pl_module.eval()
# Save the image
# if pl_module.trainer.logger:
#
# else:
# version = 0
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}")
create_circle(pl_module, fname=fname, dpi=300, figsize=(6, 6))
p = create_circle(
pl_module,
fname=fname,
dpi=self.dpi,
figsize=self.figsize,
radius=self.radius,
)
self.processes.append(p)
# Make sure to set it back to train mode
pl_module.train()
@ -35,14 +50,23 @@ class SaveImageCallback(Callback):
version = pl_module.trainer.logger.version
fname = Path(f"{self.final_dir}") / Path(f"v{version}")
pl_module.eval()
create_circle(pl_module, fname=fname, dpi=300, figsize=(6, 6))
if self.save_interval > 0:
import os
p = create_circle(
pl_module,
fname=fname,
dpi=self.dpi,
figsize=self.figsize,
radius=self.radius,
)
self.processes.append(p)
# Wait for all subprocesses to finish
for p in self.processes:
p.join()
if self.save_interval > 0:
log_dir = str(Path(pl_module.trainer.logger.log_dir))
fps = 12
_cmd = f'ffmpeg -r {fps} -f image2 -i {log_dir}/e%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "scale=1920:1080:force_original_aspect_ratio=decrease,pad=1920:1080:(ow-iw)/2:(oh-ih)/2:color=white" {log_dir}/a{version}.mp4'
os.system(_cmd)
# os.system(
# f'ffmpeg -i {log_dir}/e%04d.png -c:v libx264 -vf "fps={fps},format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" {log_dir}/a{version}.mp4'
# )
# w, h = self.figsize[0] * self.dpi, self.figsize[1] * self.dpi
w, h = 7680, 4320
_cmd = f'ffmpeg -r {fps} -f image2 -i {log_dir}/e%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2:color=white" {log_dir}/a{version}.mp4'
_ = subprocess.Popen(_cmd, shell=True)

45
check.py

@ -1,4 +1,6 @@
# import matplotlib.patches as patches
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union
@ -42,11 +44,34 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs):
plt.savefig(f"{fname}.png", **kwargs)
# def create_circle(
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
# ):
# if isinstance(ckpt, str):
# M = ColorTransformerModel.load_from_checkpoint(
# ckpt, map_location=lambda storage, loc: storage
# )
# else:
# M = ckpt
# xkcd_colors, _ = extract_colors()
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
# preds = M(xkcd_colors).detach().cpu().numpy()
# rgb_array = xkcd_colors.detach().cpu().numpy()
# plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds_serialized(serialized_data, fname, **kwargs):
# Deserialize the data
preds, rgb_array = pickle.loads(serialized_data)
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(
ckpt, map_location=lambda storage, loc: storage
)
@ -57,7 +82,16 @@ def create_circle(
xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors).detach().cpu().numpy()
rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs)
# Serialize the data
serialized_data = pickle.dumps((preds, rgb_array))
# Run plot_preds_serialized function in a separate process
p = Process(
target=plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs
)
p.start()
return p
def plot_preds(
@ -65,7 +99,7 @@ def plot_preds(
rgb_values,
fname: str,
roll: bool = False,
inner_radius: float = 1 / 3,
radius: float = 1 / 2,
dpi: int = 300,
figsize: Tuple[float] = (6, 6),
fsize: int = 0,
@ -113,12 +147,11 @@ def plot_preds(
ax.set_yticks([])
ax.set_aspect("equal")
ax.axis("off")
radius = 1
ax.set_ylim(0, radius)
ax.set_ylim(0, 1) # implicit outer radius of 1
# Overlay white circle
circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)

6
newsearch.py

@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401
# consistency of randomly sampled experiments.
seed(19920921)
NUM_JOBS = 33
NUM_JOBS = 1
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
@ -33,12 +33,12 @@ depths = [1, 2, 4, 8, 16]
# widths, depths = [512], [4]
batch_size_values = [256]
max_epochs_values = [420] # at 12 fps, around 35s
max_epochs_values = [42] # at 12 fps, around 35s
seeds = list(range(21, 1992))
optimizers = [
# "Adagrad",
"Adam",
"SGD",
# "SGD",
# "AdamW",
# "LBFGS",
# "RAdam",

9
scripts/sortcolor.py

@ -220,7 +220,7 @@ def plot_preds(
rgb_values,
fname: str,
roll: bool = False,
inner_radius: float = 1 / 3,
radius: float = 1 / 3,
dpi: int = 300,
figsize=(6, 6),
):
@ -266,12 +266,11 @@ def plot_preds(
ax.set_yticks([])
ax.set_aspect("equal")
ax.axis("off")
radius = 1
ax.set_ylim(0, radius)
ax.set_ylim(0, 1)
# Overlay white circle
circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)
@ -299,7 +298,7 @@ plot_preds(
fname,
roll=False,
dpi=DPI,
inner_radius=INNER_RADIUS,
radius=INNER_RADIUS,
figsize=(SIZE, SIZE),
)
print(f"saved {fname}")

Loading…
Cancel
Save