You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
6.3 KiB

11 months ago
# import matplotlib.patches as patches
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
11 months ago
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
11 months ago
# import matplotlib.colors as mcolors
11 months ago
11 months ago
def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs):
    """Render the model's 1-D ordering of inputs as a strip of vertical lines.

    Loads a ``ColorTransformerModel`` from ``ckpt`` and runs it on a batch of
    RGB inputs. The inputs are sorted by their flattened predictions, and one
    vertical line per input is drawn left to right in that order. The figure
    is saved to ``f"{fname}.png"`` and then closed.

    Args:
        ckpt: Checkpoint path passed to ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path without extension; ``".png"`` is appended.
        color: If ``False``, score a ramp of 949 equal-channel (gray) values
            from 0 to 1. Otherwise, score the colors returned by
            ``extract_colors`` after ``preprocess_data``.
        **kwargs: Forwarded to ``savefig`` (e.g. ``dpi``).
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    if color is False:  # black-and-white ordering
        n_grays = 949
        ramp = torch.linspace(0, 1, n_grays)
        # Repeat each gray level into all three channels -> shape (n_grays, 3).
        rgb_tensor = ramp.unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _ = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor)
    # Both branches must live on the model's device. Previously only the
    # color branch was moved, so the grayscale path broke on GPU models.
    rgb_tensor = rgb_tensor.to(model.device)

    preds = model(rgb_tensor).detach().cpu().numpy()
    rgb_values = rgb_tensor.detach().cpu().numpy()
    order = np.argsort(preds.ravel())

    fig, ax = plt.subplots()
    for x, idx in enumerate(order):
        # One unit-wide slot per input; draw the line at the slot's center.
        ax.plot(
            [x + 0.5, x + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")
    fig.savefig(f"{fname}.png", **kwargs)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def do_inference(ckpt: Union[str, ColorTransformerModel]):
    """Run a model over every color returned by ``extract_colors``.

    Args:
        ckpt: Either a checkpoint path or an already-constructed model. A
            path is loaded with ``map_location=lambda storage, loc: storage``,
            which keeps the deserialized tensors on CPU.

    Returns:
        A ``(preds, rgb_array)`` pair of NumPy arrays: the model outputs and
        the preprocessed input colors that produced them.
    """
    model = (
        ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
        if isinstance(ckpt, str)
        else ckpt
    )
    raw_colors, _ = extract_colors()
    inputs = preprocess_data(raw_colors).to(model.device)
    outputs = model(inputs)
    return outputs.detach().cpu().numpy(), inputs.detach().cpu().numpy()
def create_circle(ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs):
    """Score the colors with ``ckpt`` and save the circular plot to ``fname``.

    This is a blocking call. Extra keyword arguments are forwarded unchanged
    to ``plot_preds``.
    """
    results = do_inference(ckpt)
    plot_preds(*results, fname=fname, **kwargs)
def _plot_preds_serialized(serialized_data, fname, **kwargs):
    """Child-process target: unpickle ``(preds, rgb_array)`` and plot them.

    ``serialized_data`` is the pickle produced by
    ``create_circle_nonblocking``. Because it is unpickled, it must only ever
    come from that trusted source.
    """
    predictions, colors = pickle.loads(serialized_data)
    plot_preds(predictions, colors, fname=fname, **kwargs)
def create_circle_nonblocking(
    ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs
):
    """Score colors now, then draw the circular plot in a background process.

    Inference runs in the calling process. Only the resulting arrays are
    pickled and handed to a child ``Process`` running
    ``_plot_preds_serialized``. The process is started but not joined.

    Returns:
        The started ``multiprocessing.Process``; ``join()`` it to wait.
    """
    payload = pickle.dumps(do_inference(ckpt))
    worker = Process(
        target=_plot_preds_serialized,
        args=(payload, fname),
        kwargs=kwargs,
    )
    worker.start()
    return worker
10 months ago
def plot_preds(
    preds: np.ndarray,
    rgb_values: np.ndarray,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal-width wedges of a ring, ordered by ``preds``.

    Colors are sorted by the flattened ``preds``. They are then placed as
    wedges starting at angle pi/2 (the top, with matplotlib's default polar
    orientation). The center is masked with a white disk, and the figure is
    saved to ``fname`` and closed.

    Args:
        preds: Model outputs, one scalar per color (raveled before sorting).
        rgb_values: One row per color; only the first three columns are used.
        fname: Output path passed to ``savefig``.
        roll: If True, rotate the sorted colors so the first exact white
            ``(1, 1, 1)`` row is drawn first. If no white row exists, a
            message is printed and the order is left unchanged.
        radius: Radius of the white center disk (the ring's outer radius is 1).
        dpi: Resolution passed to ``savefig``.
        figsize: Figure size passed to ``plt.subplots``.
        fsize: Font size of ``label``; the label is only drawn when > 0.
        label: Text drawn at the center of the ring.
    """
    colors = rgb_values[np.argsort(preds.ravel()), :3]
    if roll:
        # np.where(cond) returns a tuple of index arrays, which is always
        # truthy. Test the matched indices themselves so "no white" is
        # reported instead of raising IndexError.
        white_rows = np.flatnonzero((colors == np.array([1, 1, 1])).all(axis=1))
        if white_rows.size > 0:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    n_colors = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # One wedge per color, all of equal angular width.
    theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
    width = 2 * np.pi / n_colors
    for angle, rgb in zip(theta, colors):
        ax.bar(
            angle,
            height=1,
            width=width,
            edgecolor=rgb,  # edge matches fill, presumably to hide seams
            linewidth=0.25,
            facecolor=rgb,
            bottom=0.0,
            zorder=1,
            alpha=1,
            align="edge",
        )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_ylim(0, 1)  # implicit outer radius of 1

    # White disk over the center, drawn above the wedges (zorder 2 > 1).
    # NOTE(review): relies on the private ``transData._b`` attribute of the
    # polar axes to place the circle in projected coordinates around the
    # origin; this may break across matplotlib versions.
    circle = patches.Circle(
        (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)
    if fsize > 0.0:
        ax.annotate(
            label,
            (0, 0),
            ha="center",
            va="center",
            size=fsize,
            color="black",
        )

    fig.tight_layout(pad=0)
    fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
11 months ago
if __name__ == "__main__":
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    args = parser.parse_args()

    for studio in args.studio:
        # Images for this studio are written under ./<studio>/ as v<version>.
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            candidates = glob.glob(ckpt_path)
            if not candidates:
                print(f"No checkpoint found for version {v}")
                continue
            # glob's result order is filesystem-dependent, so "last element"
            # is not necessarily the latest; pick the newest file by mtime.
            # TODO: allow specification via CLI
            ckpt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )
            # For a linear strip instead of a ring, see create_rectangle, e.g.
            # create_rectangle(ckpt, fname=name + "b", color=False, dpi=args.dpi)