colors/check.py

163 lines
5.2 KiB
Python
Raw Permalink Normal View History

2023-12-31 07:00:37 +00:00
# import matplotlib.patches as patches
2024-02-22 05:49:38 +00:00
from pathlib import Path
2024-02-22 07:37:33 +00:00
from typing import Union
2024-01-27 22:15:53 +00:00
2024-01-15 21:39:44 +00:00
import matplotlib.patches as patches
2023-12-30 06:35:19 +00:00
import matplotlib.pyplot as plt
import numpy as np
2023-12-31 05:20:28 +00:00
import torch
2023-12-30 06:35:19 +00:00
2023-12-31 06:17:15 +00:00
from dataloader import extract_colors, preprocess_data
2023-12-30 06:35:19 +00:00
from model import ColorTransformerModel
2023-12-31 21:05:54 +00:00
# import matplotlib.colors as mcolors
2023-12-31 06:17:15 +00:00
2024-01-09 18:19:38 +00:00
2024-01-25 06:12:27 +00:00
def make_image(ckpt: str, fname: str, color=True, **kwargs):
    """Render the model's 1-D ordering of colours as a strip of vertical lines.

    The checkpoint at *ckpt* is loaded, a set of RGB inputs is scored by the
    model, and the inputs are drawn left to right in ascending order of the
    model's (flattened) prediction. The figure is written to ``{fname}.png``.

    Args:
        ckpt: Path to a ``ColorTransformerModel`` checkpoint.
        fname: Output path without the ``.png`` extension.
        color: If truthy, score the palette returned by ``extract_colors()``;
            otherwise score a ramp of 949 greys with equal R, G and B values
            evenly spaced from 0 to 1.
        **kwargs: Forwarded unchanged to ``Figure.savefig``.
    """
    # Remap every storage to CPU (as create_circle does) so that checkpoints
    # saved on a GPU can be rendered on a CPU-only machine.
    model = ColorTransformerModel.load_from_checkpoint(
        ckpt, map_location=lambda storage, loc: storage
    )

    if color:
        # The second return value (names) is not needed here.
        rgb_tensor, _ = extract_colors()
    else:
        n_greys = 949
        ramp = torch.linspace(0, 1, n_greys)
        # (N,) -> (N, 3): identical R, G and B components give neutral greys.
        rgb_tensor = ramp.unsqueeze(1).repeat(1, 3)

    # Keep the raw colours for drawing; the model sees the preprocessed form.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    inputs = preprocess_data(rgb_tensor)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        preds = model(inputs)
    order = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    try:
        for position, idx in enumerate(order):
            # One vertical line per colour, centred in its unit-wide slot.
            ax.plot(
                [position + 0.5, position + 0.5],
                [0, 1],
                lw=1,
                c=rgb_values[idx],
                antialiased=True,
                alpha=1,
            )
        ax.axis("off")
        fig.savefig(f"{fname}.png", **kwargs)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
2023-12-31 05:20:28 +00:00
2024-01-27 22:15:53 +00:00
def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Score the ``extract_colors()`` palette with a model and draw it as a wheel.

    Args:
        ckpt: Either a checkpoint path, which is loaded with all storages
            mapped to CPU, or an already-constructed ``ColorTransformerModel``.
        fname: Output path without the ``.png`` extension.
        skip: Accepted but not used by this function.
        **kwargs: Forwarded to ``plot_preds`` (e.g. ``roll``, ``dpi``, ``figsize``).
    """
    if not isinstance(ckpt, str):
        # Caller already holds a model instance; use it as-is.
        model = ckpt
    else:
        # map_location returning the storage unchanged keeps tensors on CPU.
        model = ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )

    raw_colors, _ = extract_colors()
    inputs = preprocess_data(raw_colors).to(model.device)
    predictions = model(inputs)
    # NOTE(review): the colours handed to plot_preds are the *preprocessed*
    # inputs (make_image draws the raw ones); plot_preds keeps only their first
    # three columns — confirm preprocess_data leaves those as RGB in [0, 1].
    plot_preds(predictions, inputs.detach().cpu().numpy(), fname=fname, **kwargs)
2024-01-16 03:05:26 +00:00
def plot_preds(
    preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6)
):
    """Draw colours as a ring of equal wedges ordered by model prediction.

    Colours are sorted by ascending ``preds`` and laid out as equal-width polar
    bars starting at theta = pi/2. A white disc is overlaid at the centre. The
    figure is saved to ``{fname}.png`` and then closed.

    Args:
        preds: One score per colour, as a ``torch.Tensor`` or NumPy array; it
            is flattened before sorting.
        rgb_values: 2-D NumPy array with one row per colour; only the first
            three columns are used as the drawn colour.
        fname: Output path without the ``.png`` extension.
        roll: If True, rotate the sorted colours so the first row exactly equal
            to white ``[1, 1, 1]`` is drawn first. If no such row exists, a
            message is printed and the order is left unchanged.
        dpi: Resolution of the saved PNG.
        figsize: Figure size in inches, ``(width, height)``.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    order = np.argsort(preds.ravel())
    colors = rgb_values[order, :3]

    if roll:
        # Find white in colors and put it first.
        is_white = (colors == np.array([1, 1, 1])).all(axis=1)
        # Test the *array* of matches: the tuple returned by np.where is always
        # truthy, which previously made the "no white" branch unreachable and
        # raised IndexError instead.
        white_positions = np.flatnonzero(is_white)
        if white_positions.size:
            colors = np.roll(colors, -white_positions[0], axis=0)
        else:
            print("no white, skipping")

    n_colors = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
    try:
        # Wedge i starts at theta[i]; the pi/2 offset rotates the starting angle.
        theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
        width = 2 * np.pi / n_colors  # equal angular width for every wedge
        # A single call draws all wedges; bar accepts per-bar colour sequences.
        ax.bar(
            theta,
            height=1,
            width=width,
            bottom=0.0,
            align="edge",
            color=colors,
            edgecolor=colors,
            linewidth=0.25,
            zorder=1,
            alpha=1,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
        ax.axis("off")
        # NOTE(review): with r-limits of (-1, 1), r = 0 is drawn part-way out
        # from the centre rather than at it — confirm this offset is intended.
        radius = 1
        ax.set_ylim(-radius, radius)

        # Overlay a white disc at the centre. ``transData._b`` is a private
        # attribute (apparently the second half of the composite data
        # transform) and may change between matplotlib releases.
        inner_radius = 1 / 3
        circle = patches.Circle(
            (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
        )
        ax.add_patch(circle)
        fig.tight_layout(pad=0)
        fig.savefig(
            f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
        )
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
2023-12-31 21:05:54 +00:00
2023-12-31 05:20:28 +00:00
if __name__ == "__main__":
    import argparse
    import glob
    import os

    parser = argparse.ArgumentParser()
    # Accepts one or more version numbers; one image is rendered per version.
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    parser.add_argument(
        "--studio",
        default="colors-refactor-supervised",
        help="Studio whose lightning_logs are searched; also the output directory.",
    )
    parser.add_argument(
        "--roll",
        action="store_true",
        help="Rotate the wheel so that white is drawn first.",
    )
    args = parser.parse_args()

    studio = args.studio
    Path(studio).mkdir(exist_ok=True, parents=True)
    for v in args.version:
        name = f"{studio}/v{v}"
        ckpt_glob = (
            f"/teamspace/studios/{studio}/colors/lightning_logs/"
            f"version_{v}/checkpoints/*.ckpt"
        )
        candidates = glob.glob(ckpt_glob)
        if not candidates:
            print(f"No checkpoint found for version {v}")
            continue
        # glob.glob returns matches in arbitrary order; pick the newest file
        # explicitly instead of whichever happens to be listed last.
        ckpt = max(candidates, key=os.path.getmtime)
        print(f"Generating image for checkpoint: {ckpt}")
        create_circle(
            ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=args.roll
        )