You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
6.5 KiB

11 months ago
# import matplotlib.patches as patches
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
11 months ago
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
11 months ago
# import matplotlib.colors as mcolors
11 months ago
11 months ago
10 months ago
def make_image(ckpt: str, fname: str, color=True, **kwargs):
    """Draw colors as a strip of vertical lines ordered by the model's output.

    A ``ColorTransformerModel`` is loaded from ``ckpt`` and evaluated on a
    batch of colors. The colors are then sorted by the (flattened) prediction
    and one vertical line per color is plotted in that order.

    Args:
        ckpt: Checkpoint path handed to ``load_from_checkpoint``.
        fname: Output path without extension; ``.png`` is appended.
        color: When true, the colors come from ``extract_colors()``.
            Otherwise a 949-step ramp from 0 to 1 with identical R, G and B
            channels (i.e. grays) is used.
        **kwargs: Forwarded unchanged to ``savefig``.
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)

    if color:
        rgb_tensor, _names = extract_colors()
    else:
        n_grays = 949
        ramp = torch.linspace(0, 1, n_grays)
        rgb_tensor = ramp.unsqueeze(1).repeat(1, 3)

    # Grab the plotting colors before preprocess_data is applied, so the
    # lines are drawn with the original RGB values on both branches.
    rgb_values = rgb_tensor.detach().numpy()
    rgb_tensor = preprocess_data(rgb_tensor)
    preds = model(rgb_tensor)
    order = np.argsort(preds.detach().numpy().ravel())

    fig, ax = plt.subplots()
    try:
        # A distinct loop name avoids shadowing the ``color`` parameter.
        for position, idx in enumerate(order):
            x = position + 0.5
            ax.plot(
                [x, x],
                [0, 1],
                lw=1,
                c=rgb_values[idx],
                antialiased=True,
                alpha=1,
            )
        ax.axis("off")
        fig.savefig(f"{fname}.png", **kwargs)
    finally:
        # Release the figure even if drawing/saving fails; otherwise every
        # call leaves an open figure behind in pyplot's registry.
        plt.close(fig)
# def create_circle(
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
# ):
# if isinstance(ckpt, str):
# M = ColorTransformerModel.load_from_checkpoint(
# ckpt, map_location=lambda storage, loc: storage
# )
# else:
# M = ckpt
# xkcd_colors, _ = extract_colors()
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
# preds = M(xkcd_colors).detach().cpu().numpy()
# rgb_array = xkcd_colors.detach().cpu().numpy()
# plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds_serialized(serialized_data, fname, **kwargs):
    """Unpickle a ``(preds, rgb_array)`` pair and plot it with ``plot_preds``.

    Args:
        serialized_data: Pickled two-item tuple ``(preds, rgb_array)``.
        fname: Output file name passed on to ``plot_preds``.
        **kwargs: Extra keyword arguments passed on to ``plot_preds``.
    """
    # NOTE: pickle.loads must only ever see data produced by this program.
    payload = pickle.loads(serialized_data)
    predictions, colors = payload
    plot_preds(predictions, colors, fname=fname, **kwargs)
def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Evaluate a model on the extracted colors and plot them in a child process.

    Args:
        ckpt: Either a checkpoint path (loaded with ``load_from_checkpoint``)
            or an already constructed model, which is used as-is.
        fname: Output file name forwarded to the plotting function.
        skip: Accepted for interface compatibility; not read by this function.
        **kwargs: Forwarded to ``plot_preds_serialized`` in the child process.

    Returns:
        The started ``multiprocessing.Process``; the caller may ``join`` it.
    """
    model = (
        ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
        if isinstance(ckpt, str)
        else ckpt
    )

    raw_colors, _ = extract_colors()
    model_input = preprocess_data(raw_colors).to(model.device)

    # Move everything to host memory as numpy arrays before pickling.
    predictions = model(model_input).detach().cpu().numpy()
    rgb = model_input.detach().cpu().numpy()
    payload = pickle.dumps((predictions, rgb))

    # Plotting happens in a separate process so the caller is not blocked.
    worker = Process(
        target=plot_preds_serialized, args=(payload, fname), kwargs=kwargs
    )
    worker.start()
    return worker
10 months ago
def plot_preds(
    preds: np.ndarray,
    rgb_values,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Save a ring of equal-width colored wedges ordered by ``preds``.

    The rows of ``rgb_values`` are sorted by the flattened ``preds`` and drawn
    as bars on a polar axis, then a white disc is overlaid on the centre.

    Args:
        preds: Array of predictions; flattened and arg-sorted to order colors.
        rgb_values: 2-D array indexed as ``rgb_values[order, :3]``; the first
            three columns are used as each wedge's color.
        fname: Path handed to ``savefig``.
        roll: If true, rotate the sorted colors so the first row exactly equal
            to ``[1, 1, 1]`` comes first. Prints a message and leaves the
            order unchanged when no such row exists.
        radius: Radius of the white disc drawn over the centre (the wedges
            extend to a radius of 1).
        dpi: Resolution passed to ``savefig``.
        figsize: Figure size passed to ``plt.subplots``.
        fsize: Font size of the centre label; the label is drawn only when
            this is greater than zero.
        label: Text annotated at the centre.
    """
    order = np.argsort(preds.ravel())
    colors = rgb_values[order, :3]

    if roll:
        white = np.array([1, 1, 1])
        # np.where would return a (always truthy) tuple; work with the index
        # array directly so the "no white" case is actually detected.
        white_rows = np.flatnonzero((colors == white).all(axis=1))
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    n_colors = len(colors)

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
    try:
        # Wedge start angles, offset by a quarter turn.
        theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
        # Equal angular size per wedge; guard against an empty color table.
        width = 2 * np.pi / n_colors if n_colors else 0.0

        for angle, wedge_color in zip(theta, colors):
            ax.bar(
                angle,
                height=1,
                width=width,
                # Matching edge color hides hairline gaps between wedges.
                edgecolor=wedge_color,
                linewidth=0.25,
                facecolor=wedge_color,
                bottom=0.0,
                zorder=1,
                alpha=1,
                align="edge",
            )

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
        ax.axis("off")
        ax.set_ylim(0, 1)  # implicit outer radius of 1

        # Overlay white circle (zorder above the wedges).
        circle = patches.Circle(
            (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
        )
        ax.add_patch(circle)

        if fsize > 0.0:
            center = (0, 0)
            ax.annotate(
                label,
                center,
                ha="center",
                va="center",
                size=fsize,
                color="black",
            )

        fig.tight_layout(pad=0)
        fig.savefig(
            fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight"
        )
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
11 months ago
if __name__ == "__main__":
    # Command-line entry point: for every (studio, version) pair, find a
    # checkpoint under the studio's lightning_logs directory and render its
    # color ring to "<studio>/v<version>".
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        # nargs="+" yields a list when the flag is given, so the default has
        # to be a list as well; a bare string default would be iterated one
        # character at a time in the loop below.
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    args = parser.parse_args()

    for studio in args.studio:
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            matches = glob.glob(ckpt_path)
            if not matches:
                print(f"No checkpoint found for version {v}")
                continue

            # glob gives no ordering guarantee, so choose the most recently
            # modified checkpoint instead of whatever happens to be last.
            ckpt = max(matches, key=lambda p: Path(p).stat().st_mtime)
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )
            # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)