You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

207 lines
6.3 KiB

# import matplotlib.patches as patches
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
# import matplotlib.colors as mcolors
def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs):
    """Save a strip of vertical lines showing the model's ordering of colors.

    Each input color is drawn as one vertical line; lines are placed left to
    right in ascending order of the model's prediction for that color.

    Args:
        ckpt: Checkpoint path handed to
            ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path *without* extension; ``.png`` is appended.
        color: When ``False``, order a grey ramp (equal R, G and B going from
            0 to 1) instead of the colors returned by ``extract_colors()``.
        **kwargs: Forwarded unchanged to ``Figure.savefig``.
    """
    M = ColorTransformerModel.load_from_checkpoint(ckpt)

    if color is False:  # black and white ordering
        # NOTE(review): 949 presumably mirrors the size of the color list
        # returned by extract_colors() -- confirm.
        num_greys = 949
        ramp = torch.linspace(0, 1, num_greys)
        rgb_tensor = ramp.unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _ = extract_colors()

    # NOTE(review): applied to both branches; the original indentation was
    # lost, so confirm the grey ramp is meant to be preprocessed as well.
    rgb_tensor = preprocess_data(rgb_tensor).to(M.device)
    preds = M(rgb_tensor).detach().cpu().numpy()
    rgb_values = rgb_tensor.detach().cpu().numpy()
    sorted_inds = np.argsort(preds.ravel())

    fig, ax = plt.subplots()
    # `position` is the rank in the predicted ordering, `idx` the row of the
    # color that holds that rank. A separate name is used for the line color
    # so the `color` parameter is not shadowed.
    for position, idx in enumerate(sorted_inds):
        x = position + 0.5
        ax.plot([x, x], [0, 1], lw=1, c=rgb_values[idx], antialiased=True, alpha=1)
    ax.axis("off")

    fig.savefig(f"{fname}.png", **kwargs)
    # Release the figure; pyplot keeps every figure alive until it is closed.
    plt.close(fig)
def do_inference(ckpt: Union[str, ColorTransformerModel]):
    """Run the model over the colors returned by ``extract_colors()``.

    Args:
        ckpt: Either a checkpoint path, which is loaded with a
            ``map_location`` callable that returns each storage unchanged,
            or an already constructed model, which is used as-is.

    Returns:
        A ``(preds, rgb_array)`` pair of NumPy arrays: the model outputs and
        the preprocessed inputs that produced them.
    """
    if not isinstance(ckpt, str):
        model = ckpt
    else:
        model = ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )

    raw_colors, _ = extract_colors()
    # Inputs are moved to whatever device the model lives on.
    inputs = preprocess_data(raw_colors).to(model.device)
    outputs = model(inputs)

    preds = outputs.detach().cpu().numpy()
    rgb_array = inputs.detach().cpu().numpy()
    return preds, rgb_array
def create_circle(ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs):
    """Run inference and draw the color wheel in the calling process.

    Args:
        ckpt: Checkpoint path or model, passed to ``do_inference``.
        fname: Output file name, passed to ``plot_preds``.
        **kwargs: Extra keyword arguments forwarded to ``plot_preds``.
    """
    inference_result = do_inference(ckpt)
    # do_inference returns (preds, rgb_array), matching plot_preds' first
    # two positional parameters.
    plot_preds(*inference_result, fname=fname, **kwargs)
def _plot_preds_serialized(serialized_data, fname, **kwargs):
    """Unpickle a ``(preds, rgb_array)`` pair and plot it.

    Used as the ``Process`` target in ``create_circle_nonblocking``.

    Args:
        serialized_data: Bytes produced by ``pickle.dumps((preds, rgb_array))``.
        fname: Output file name, passed to ``plot_preds``.
        **kwargs: Extra keyword arguments forwarded to ``plot_preds``.
    """
    payload = pickle.loads(serialized_data)
    predictions, colors = payload
    plot_preds(predictions, colors, fname=fname, **kwargs)
def create_circle_nonblocking(
    ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs
):
    """Run inference here, then draw the color wheel in a child process.

    Inference happens in the calling process; only the plotting step is
    handed to a separate ``multiprocessing.Process``.

    Args:
        ckpt: Checkpoint path or model, passed to ``do_inference``.
        fname: Output file name, passed on to ``plot_preds``.
        **kwargs: Extra keyword arguments forwarded to ``plot_preds``.

    Returns:
        The started ``Process``; the caller decides whether to ``join`` it.
    """
    preds, rgb_array = do_inference(ckpt)

    # The arrays are pickled up front and handed to the worker as bytes.
    payload = pickle.dumps((preds, rgb_array))

    worker = Process(
        target=_plot_preds_serialized,
        args=(payload, fname),
        kwargs=kwargs,
    )
    worker.start()
    return worker
def plot_preds(
    preds: np.ndarray,
    rgb_values: np.ndarray,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Save a color wheel whose wedges follow the predicted ordering.

    The colors are sorted by ascending prediction and drawn as equal-width
    wedges on a polar axis. A white disc of ``radius`` is overlaid at the
    centre, with an optional text label on top.

    Args:
        preds: Model outputs, one per color; flattened before sorting.
        rgb_values: Array with one row per color; only the first three
            columns are used as the wedge color.
        fname: Path the figure is saved to.
        roll: When ``True``, rotate the ordering so the first pure-white row
            (all three channels exactly 1) comes first. If there is no such
            row, a message is printed and the ordering is left alone.
        radius: Radius of the white disc (the wedges extend to radius 1).
        dpi: Resolution passed to ``savefig``.
        figsize: Figure size as ``(width, height)``.
        fsize: Font size for ``label``; the label is drawn only when > 0.
        label: Text placed at the centre of the wheel.
    """
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]

    if roll:
        white = np.array([1, 1, 1])
        # np.where(cond) returns a *tuple* of index arrays, which is truthy
        # even when nothing matched. Test the matching row indices instead,
        # so the "no white" case is reported rather than raising IndexError.
        white_rows = np.flatnonzero((colors == white).all(axis=1))
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    N = len(colors)

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # One equal-width wedge per color; the pi/2 offset shifts every start
    # angle by a quarter turn.
    theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
    width = 2 * np.pi / (N)
    for angle, wedge_color in zip(theta, colors):
        ax.bar(
            angle,
            height=1,
            width=width,
            # The edge uses the fill color, so no outline shows between wedges.
            edgecolor=wedge_color,
            linewidth=0.25,
            facecolor=wedge_color,
            bottom=0.0,
            zorder=1,
            alpha=1,
            align="edge",
        )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_ylim(0, 1)  # implicit outer radius of 1

    # Overlay white circle (zorder=2 puts it above the wedges).
    # NOTE(review): `transData._b` is a private matplotlib attribute,
    # presumably used so the patch is not warped by the polar projection --
    # confirm it still exists in the matplotlib version in use.
    circle = patches.Circle(
        (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    if fsize > 0.0:
        ax.annotate(
            label,
            (0, 0),
            ha="center",
            va="center",
            size=fsize,
            color="black",
        )

    fig.tight_layout(pad=0)
    fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: for every requested studio and run version,
    # find a checkpoint under that run's lightning_logs directory and save
    # its color wheel to "<studio>/v<version>".
    import argparse
    import glob
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    args = parser.parse_args()

    for studio in args.studio:
        # Images are written into a local directory named after the studio.
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            candidates = glob.glob(ckpt_path)
            if not candidates:
                print(f"No checkpoint found for version {v}")
                continue

            # glob returns matches in arbitrary order, so indexing the last
            # element does not reliably give the latest checkpoint; pick the
            # most recently modified file instead.
            # TODO: allow specification via CLI
            ckpt = max(candidates, key=os.path.getmtime)
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )