|
|
|
# import matplotlib.patches as patches
|
|
|
|
import pickle
|
|
|
|
from multiprocessing import Process
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Tuple, Union
|
|
|
|
|
|
|
|
import matplotlib.patches as patches
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from dataloader import extract_colors, preprocess_data
|
|
|
|
from model import ColorTransformerModel
|
|
|
|
|
|
|
|
# import matplotlib.colors as mcolors
|
|
|
|
|
|
|
|
|
|
|
|
def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs):
    """Render model-ordered colors as a strip of vertical lines and save a PNG.

    Args:
        ckpt: path to a ``ColorTransformerModel`` checkpoint.
        fname: output path without extension; ``.png`` is appended.
        color: if truthy, use the colors returned by ``extract_colors()``;
            if falsy, use a grayscale ramp of 949 evenly spaced levels in
            [0, 1] (each level repeated across the three channels).
        **kwargs: forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    M = ColorTransformerModel.load_from_checkpoint(ckpt)

    if not color:  # black-and-white ordering
        n_levels = 949
        # Column vector of gray levels, repeated so each row is (g, g, g).
        rgb_tensor = torch.linspace(0, 1, n_levels).unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _ = extract_colors()

    rgb_tensor = preprocess_data(rgb_tensor).to(M.device)
    # Inference only: avoid building an autograd graph that is thrown away.
    with torch.no_grad():
        preds = M(rgb_tensor).detach().cpu().numpy()
    rgb_values = rgb_tensor.detach().cpu().numpy()
    sorted_inds = np.argsort(preds.ravel())

    fig, ax = plt.subplots()
    for x, idx in enumerate(sorted_inds):
        # One vertical line per color, centered in its unit-wide slot.
        ax.plot(
            [x + 0.5, x + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")

    fig.savefig(f"{fname}.png", **kwargs)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
def do_inference(ckpt: Union[str, ColorTransformerModel]):
    """Score every color returned by ``extract_colors()`` with a model.

    Args:
        ckpt: a checkpoint path (``str`` or ``pathlib.Path``), which is
            loaded with every storage mapped to CPU, or an already-built
            ``ColorTransformerModel`` that is used as-is.

    Returns:
        ``(preds, rgb_array)``: the model outputs and the preprocessed input
        colors, both as NumPy arrays on the CPU, in the same row order.
    """
    if isinstance(ckpt, (str, Path)):
        # Returning the storage unchanged keeps tensors on the CPU, so a
        # checkpoint saved on GPU still loads on a CPU-only machine.
        M = ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
    else:
        M = ckpt

    xkcd_colors, _ = extract_colors()
    xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
    # Inference only: don't record an autograd graph (this may also be
    # called with a model that is mid-training).
    with torch.no_grad():
        preds = M(xkcd_colors).detach().cpu().numpy()
    rgb_array = xkcd_colors.detach().cpu().numpy()
    return preds, rgb_array
|
|
|
|
|
|
|
|
|
|
|
|
def create_circle(ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs):
    """Run inference and draw the resulting color ordering as a ring.

    Args:
        ckpt: checkpoint path or an already-loaded ``ColorTransformerModel``;
            passed straight to ``do_inference``.
        fname: output image path, passed to ``plot_preds``.
        **kwargs: extra options forwarded to ``plot_preds``
            (e.g. ``dpi``, ``figsize``, ``roll``).
    """
    plot_preds(*do_inference(ckpt), fname=fname, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def _plot_preds_serialized(serialized_data, fname, **kwargs):
    """Unpickle a ``(preds, rgb_array)`` pair and plot it with ``plot_preds``.

    This is the target run in the child process started by
    ``create_circle_nonblocking``. The payload is produced in-process with
    ``pickle.dumps``; never pass untrusted bytes here, because unpickling
    can execute arbitrary code.
    """
    payload = pickle.loads(serialized_data)
    predictions, rgb = payload
    plot_preds(predictions, rgb, fname=fname, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def create_circle_nonblocking(
    ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs
):
    """Run inference now, then plot the ring in a separate process.

    Inference happens synchronously in the calling process; only the
    plotting/saving step is handed to a child ``Process``.

    Args:
        ckpt: checkpoint path or an already-loaded ``ColorTransformerModel``.
        fname: output image path for ``plot_preds``.
        **kwargs: extra options forwarded to ``plot_preds``.

    Returns:
        The started ``multiprocessing.Process``. It is not joined here, so
        the caller decides whether to wait on it.
    """
    # do_inference already returns the (preds, rgb_array) tuple to ship.
    payload = pickle.dumps(do_inference(ckpt))

    worker = Process(
        target=_plot_preds_serialized,
        args=(payload, fname),
        kwargs=kwargs,
    )
    worker.start()
    return worker
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(
    preds: np.ndarray,
    rgb_values: np.ndarray,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal wedges of a ring, ordered by predicted value.

    Args:
        preds: model outputs. They are flattened and argsorted to order the
            colors, presumably one value per row of ``rgb_values``.
        rgb_values: per-color array whose first three columns are used as
            RGB. Presumably floats in [0, 1], since matplotlib rejects
            values outside that range.
        fname: output path passed to ``Figure.savefig``.
        roll: if True, rotate the ordering so that the first pure-white
            (1, 1, 1) entry is drawn first. If there is none, print a notice
            and keep the order.
        radius: radius of the white ``patches.Circle`` drawn over the center.
        dpi: resolution of the saved image.
        figsize: figure size in inches.
        fsize: font size for ``label``; the label is drawn only when > 0.
        label: text placed at the center of the ring.

    Raises:
        ValueError: if there are no colors to plot.
    """
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]

    if roll:
        # Find white in colors and put it first.
        white = np.array([1, 1, 1])
        # np.where would return a tuple, which is always truthy. Check the
        # flat index array's size instead, so the "no white" branch is
        # actually reachable.
        white_rows = np.flatnonzero((colors == white).all(axis=1))
        if white_rows.size > 0:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    N = len(colors)
    if N == 0:
        raise ValueError("plot_preds needs at least one color to plot")

    # Create a plot with these hues in a circle.
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # Start angle of each wedge, offset by pi/2 from the polar axes' zero.
    theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
    width = 2 * np.pi / N  # equal angular size for each wedge

    # One vectorized call draws all N wedges. Per-wedge face and edge colors
    # are given as (N, 3) arrays, so each wedge still gets its own color.
    ax.bar(
        theta,
        height=1,
        width=width,
        bottom=0.0,
        color=colors,
        edgecolor=colors,
        linewidth=0.25,
        zorder=1,
        alpha=1,
        align="edge",
    )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_ylim(0, 1)  # implicit outer radius of 1

    # Overlay a white disc at the pole, drawn above the wedges (zorder=2).
    # NOTE(review): ``transData._b`` is a private attribute of matplotlib's
    # composite polar transform. It is used so the Circle is placed in
    # Cartesian coordinates centered on the pole rather than in
    # (theta, r) coordinates. It may break across matplotlib versions.
    circle = patches.Circle(
        (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    if fsize > 0.0:
        ax.annotate(
            label,
            (0, 0),
            ha="center",
            va="center",
            size=fsize,
            color="black",
        )

    fig.tight_layout(pad=0)
    fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    import glob
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    # float accepts every integer the old int-typed option accepted.
    parser.add_argument("--figsize", type=float, default=6, help="Figure size")
    args = parser.parse_args()
    versions = args.version

    for studio in args.studio:
        # Output images go to ./<studio>/v<version>.png (relative to the CWD).
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in versions:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            ckpts = glob.glob(ckpt_path)
            if not ckpts:
                print(f"No checkpoint found for version {v}")
                continue

            # glob returns files in arbitrary order, so pick the most
            # recently written checkpoint explicitly.
            # TODO: allow specification via CLI
            ckpt = max(ckpts, key=os.path.getmtime)
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=(args.figsize, args.figsize),
                roll=False,
            )
            # Grayscale strip variant:
            # create_rectangle(ckpt, fname=name + "b", color=False, dpi=args.dpi)
|