You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

185 lines
5.6 KiB

# import matplotlib.patches as patches
from pathlib import Path
from typing import Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
# import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True, **kwargs):
    """Render the model's 1-D ordering of colors as a strip of vertical lines.

    Each input color is drawn as a 1-unit-wide vertical line; lines are placed
    left to right in ascending order of the model's prediction for that color.

    Args:
        ckpt: Path to a ``ColorTransformerModel`` checkpoint.
        fname: Output path without extension; ``".png"`` is appended.
        color: If True, use the colors returned by ``extract_colors()``;
            otherwise use a 949-step grayscale ramp with R = G = B.
        **kwargs: Forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)

    if color:
        rgb_tensor, _names = extract_colors()
    else:
        # Evenly spaced gray levels in [0, 1], each copied into the R, G, B columns.
        n_grays = 949
        rgb_tensor = torch.linspace(0, 1, n_grays).unsqueeze(1).repeat(1, 3)

    # Keep the raw (pre-preprocessing) values for drawing; only the model sees
    # the preprocessed tensor, which must live on the model's device.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    model_input = preprocess_data(rgb_tensor).to(model.device)
    with torch.no_grad():
        preds = model(model_input)
    # Bring predictions back to host memory before converting to NumPy.
    sorted_inds = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    for position, idx in enumerate(sorted_inds):
        ax.plot(
            [position + 0.5, position + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")
    fig.savefig(f"{fname}.png", **kwargs)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Draw the model's ordering of the extracted colors as a color wheel.

    Args:
        ckpt: Either a checkpoint path, which is loaded with all storages kept
            on CPU, or an already-constructed ``ColorTransformerModel``.
        fname: Output image path, passed through to ``plot_preds``.
        skip: Accepted but currently unused by this function.
        **kwargs: Extra keyword arguments forwarded to ``plot_preds``.
    """
    model = (
        ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
        if isinstance(ckpt, str)
        else ckpt
    )

    raw_colors, _ = extract_colors()
    inputs = preprocess_data(raw_colors).to(model.device)

    # Predictions are computed first, then the same preprocessed inputs are
    # copied back to host memory and used as the colors to paint.
    plot_preds(
        model(inputs),
        inputs.detach().cpu().numpy(),
        fname=fname,
        **kwargs,
    )
def _ordered_colors(preds: np.ndarray, rgb_values, roll: bool) -> np.ndarray:
    """Return the first three columns of *rgb_values*, reordered by ascending *preds*.

    When *roll* is True the result is rotated so the first row exactly equal
    to white ``(1, 1, 1)`` comes first; if there is no such row a message is
    printed and the sorted order is returned unchanged.
    """
    if isinstance(rgb_values, torch.Tensor):
        rgb_values = rgb_values.detach().cpu().numpy()
    order = np.argsort(np.ravel(preds))
    colors = np.asarray(rgb_values)[order, :3]
    if roll:
        # Exact comparison: only rows that are precisely (1, 1, 1) count as white.
        is_white = (colors == np.array([1, 1, 1])).all(axis=1)
        white_rows = np.flatnonzero(is_white)
        # Test the index array's size (a tuple from np.where would always be
        # truthy and make the fallback unreachable).
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")
    return colors


def plot_preds(
    preds: Union[torch.Tensor, np.ndarray],
    rgb_values,
    fname: str,
    roll: bool = False,
    inner_radius: float = 1 / 3,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal-width wedges of a ring, ordered by prediction value.

    Args:
        preds: One scalar per color (flattened with ``ravel``); wedges are laid
            out in ascending order of these values.
        rgb_values: Array-like (or tensor) indexed by row; the first three
            columns of each row are the wedge color. Expected to have one row
            per element of ``preds``.
        fname: Output path passed to ``Figure.savefig``.
        roll: If True, rotate the ordering so the first exact white starts the ring.
        inner_radius: Radius of the white disc drawn over the center (the
            outer radius is 1).
        dpi: Resolution of the saved image.
        figsize: Figure size in inches.
        fsize: Font size for ``label``; the label is drawn only when > 0.
        label: Text placed at the center of the ring.

    Raises:
        ValueError: If ``preds`` is empty (there would be no wedges to draw).
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if np.size(preds) == 0:
        raise ValueError("plot_preds needs at least one prediction to draw")

    colors = _ordered_colors(preds, rgb_values, roll)
    n_wedges = len(colors)
    # With align="edge" each wedge starts at its theta; the first starts at pi/2.
    theta = np.linspace(0, 2 * np.pi, n_wedges, endpoint=False) + np.pi / 2
    width = 2 * np.pi / n_wedges  # equal angular size for every wedge

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
    try:
        # One vectorized call draws all wedges (same patches, same order as a
        # per-wedge loop). Edges use the face color, presumably to hide hairline
        # gaps between neighbouring wedges.
        ax.bar(
            theta,
            height=1,
            width=width,
            bottom=0.0,
            align="edge",
            color=colors,
            edgecolor=colors,
            linewidth=0.25,
            zorder=1,
            alpha=1,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
        ax.axis("off")
        outer_radius = 1
        ax.set_ylim(0, outer_radius)

        # NOTE(review): ``transData._b`` is a private Matplotlib attribute (the
        # second half of the polar data transform); it is presumably used so the
        # disc is placed in Cartesian units at the pole. May break across versions.
        hole = patches.Circle(
            (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
        )
        ax.add_patch(hole)

        if fsize > 0:
            ax.annotate(
                label,
                (0, 0),
                ha="center",
                va="center",
                size=fsize,
                color="black",
            )

        fig.tight_layout(pad=0)
        fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    # Render one color-wheel image per (studio, version) checkpoint directory.
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        # Must be a list: with nargs="+" a bare-string default would be
        # iterated character by character ("t", "h", "i", ...) below.
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    # float so non-integer sizes are accepted; the default (6) is unchanged.
    parser.add_argument("--figsize", type=float, default=6, help="Figure size")
    args = parser.parse_args()

    for studio in args.studio:
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            ckpts = glob.glob(ckpt_path)
            if not ckpts:
                print(f"No checkpoint found for version {v}")
                continue
            # glob order is arbitrary, so choose the most recently written file.
            ckpt = max(ckpts, key=lambda p: Path(p).stat().st_mtime)
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )