You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

216 lines
6.5 KiB

# import matplotlib.patches as patches
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
# import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True, **kwargs):
    """Render model predictions as a strip of colored vertical lines and save it.

    Loads a ``ColorTransformerModel`` from ``ckpt`` and runs it on a set of
    input colors. The inputs are sorted by the model's prediction (assumes one
    prediction per input color). One vertical line is drawn per input, in
    sorted order, using that input's un-preprocessed RGB value.

    Args:
        ckpt: Checkpoint path accepted by
            ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path without extension; ``.png`` is appended.
        color: If True, use the colors returned by ``extract_colors()``.
            If False, use a 949-step gray ramp (equal R, G, B from 0 to 1).
        **kwargs: Forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    if not color:
        n_steps = 949
        linear_space = torch.linspace(0, 1, n_steps)
        # Repeat each gray level across the three channels.
        rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _names = extract_colors()
    # Keep the raw colors for drawing; only the model sees preprocessed data.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    # The checkpoint may be restored onto the device it was saved from, so
    # move the input to the model's device (as create_circle does).
    model_input = preprocess_data(rgb_tensor).to(model.device)
    with torch.no_grad():  # inference only; no autograd graph needed
        preds = model(model_input)
    sorted_inds = np.argsort(preds.cpu().numpy().ravel())

    fig, ax = plt.subplots()
    try:
        for pos, idx in enumerate(sorted_inds):
            ax.plot(
                [pos + 0.5, pos + 0.5],
                [0, 1],
                lw=1,
                c=rgb_values[idx],
                antialiased=True,
                alpha=1,
            )
        ax.axis("off")
        fig.savefig(f"{fname}.png", **kwargs)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# def create_circle(
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
# ):
# if isinstance(ckpt, str):
# M = ColorTransformerModel.load_from_checkpoint(
# ckpt, map_location=lambda storage, loc: storage
# )
# else:
# M = ckpt
# xkcd_colors, _ = extract_colors()
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
# preds = M(xkcd_colors).detach().cpu().numpy()
# rgb_array = xkcd_colors.detach().cpu().numpy()
# plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds_serialized(serialized_data, fname, **kwargs):
    """Unpickle a ``(preds, rgb_array)`` pair and draw it with ``plot_preds``.

    Used as the ``Process`` target in ``create_circle``. ``serialized_data``
    is expected to be the bytes of ``pickle.dumps((preds, rgb_array))``.
    Unpickling can execute arbitrary code, so pass only trusted bytes.

    Args:
        serialized_data: Pickled 2-tuple of prediction and color arrays.
        fname: Output file name forwarded to ``plot_preds``.
        **kwargs: Extra keyword arguments forwarded to ``plot_preds``.
    """
    payload = pickle.loads(serialized_data)
    predictions, colors = payload
    plot_preds(predictions, colors, fname=fname, **kwargs)
def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Run the model on ``extract_colors()`` and plot the result in a child process.

    Args:
        ckpt: Checkpoint path, or an already-constructed
            ``ColorTransformerModel``. Paths are loaded with
            ``map_location=lambda storage, loc: storage`` (storages stay on CPU).
        fname: Output file name forwarded to ``plot_preds``.
        skip: Accepted for compatibility; not used.
        **kwargs: Forwarded to ``plot_preds`` (e.g. ``dpi``, ``figsize``, ``roll``).

    Returns:
        The already-started ``multiprocessing.Process`` doing the plotting;
        callers may ``join()`` it.
    """
    model = (
        ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
        if isinstance(ckpt, str)
        else ckpt
    )

    colors, _ = extract_colors()
    inputs = preprocess_data(colors).to(model.device)
    outputs = model(inputs)

    # Hand the child one bytes payload containing (predictions, input colors).
    payload = pickle.dumps(
        (outputs.detach().cpu().numpy(), inputs.detach().cpu().numpy())
    )

    worker = Process(
        target=plot_preds_serialized,
        args=(payload, fname),
        kwargs=kwargs,
    )
    worker.start()
    return worker
def _roll_white_first(colors: np.ndarray) -> np.ndarray:
    """Rotate rows of ``colors`` so the first exactly-white row (1, 1, 1) leads.

    Prints a notice and returns ``colors`` unchanged if no row is exactly white.
    """
    white = np.array([1, 1, 1])
    # np.where(cond) returns a *tuple* of index arrays (always truthy), so use
    # flatnonzero and test the size of the result instead.
    white_rows = np.flatnonzero((colors == white).all(axis=1))
    if white_rows.size == 0:
        print("no white, skipping")
        return colors
    return np.roll(colors, -int(white_rows[0]), axis=0)


def plot_preds(
    preds: np.ndarray,
    rgb_values,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal wedges of a ring, ordered by ``preds``, and save it.

    Args:
        preds: Model outputs; flattened and argsorted to get the wedge order
            (assumes one value per row of ``rgb_values``).
        rgb_values: Row-indexable array; the first three columns of each row
            are used as the wedge color.
        fname: Output path passed to ``Figure.savefig``.
        roll: If True, rotate the order so the first exactly-white color
            comes first.
        radius: Radius of the white disc drawn over the center (presumably in
            the radial-axis units, whose outer limit is set to 1 below).
        dpi: Resolution passed to ``Figure.savefig``.
        figsize: Figure size passed to ``plt.subplots``.
        fsize: Font size for ``label``; the label is drawn only when > 0.
        label: Text drawn at the center of the ring.

    Raises:
        ValueError: If there are no colors to draw.
    """
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]
    if roll:
        colors = _roll_white_first(colors)
    n_colors = len(colors)
    if n_colors == 0:
        raise ValueError("plot_preds needs at least one color to draw")

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
    try:
        # One equal-width wedge per color; offset by pi/2 so the first wedge
        # starts at the top with matplotlib's default polar orientation.
        theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
        width = 2 * np.pi / n_colors
        # A single vectorized bar() call draws the same patches, in the same
        # order, as calling bar() once per color.
        ax.bar(
            theta,
            height=1,
            width=width,
            color=colors,
            edgecolor=colors,
            linewidth=0.25,
            bottom=0.0,
            zorder=1,
            alpha=1,
            align="edge",
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
        ax.axis("off")
        ax.set_ylim(0, 1)  # implicit outer radius of 1

        # White disc over the center turns the filled pie into a ring.
        # NOTE(review): `transData._b` is a private matplotlib attribute
        # (seemingly the post-projection part of the polar transform); it may
        # change between matplotlib versions.
        circle = patches.Circle(
            (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
        )
        ax.add_patch(circle)

        if fsize > 0.0:
            ax.annotate(
                label,
                (0, 0),
                ha="center",
                va="center",
                size=fsize,
                color="black",
            )

        fig.tight_layout(pad=0)
        fig.savefig(
            fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight"
        )
    finally:
        # Close this specific figure even if drawing or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    # One or more Lightning `version_N` indices to render.
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        # Must be a list to match nargs="+": a bare string default would be
        # iterated character by character in the loop below.
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    args = parser.parse_args()

    workers = []
    for studio in args.studio:
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            # glob() order is arbitrary; sort so the chosen checkpoint (the
            # lexicographically last one) is deterministic.
            ckpts = sorted(glob.glob(ckpt_path))
            if not ckpts:
                print(f"No checkpoint found for version {v}")
                continue
            ckpt = ckpts[-1]
            print(f"Generating image for checkpoint: {ckpt}")
            workers.append(
                create_circle(
                    ckpt,
                    fname=name,
                    dpi=args.dpi,
                    figsize=[args.figsize] * 2,
                    roll=False,
                )
            )

    # Explicitly wait for every plotting process before exiting.
    for worker in workers:
        worker.join()