You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
5.1 KiB

11 months ago
# import matplotlib.patches as patches
from typing import Union
from pathlib import Path
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
11 months ago
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
11 months ago
# import matplotlib.colors as mcolors
11 months ago
11 months ago
10 months ago
def make_image(ckpt: str, fname: str, color=True, **kwargs):
    """Render the model's ordering of a color set as a strip of vertical lines.

    Loads a ``ColorTransformerModel`` from ``ckpt``, runs it on a set of
    colors, sorts the colors by the flattened model output, and draws one
    vertical line per color, left to right, in that order.

    Args:
        ckpt: Checkpoint path handed to ``load_from_checkpoint``.
        fname: Output path without extension; ``.png`` is appended.
        color: If true, use the colors returned by ``extract_colors()``.
            Otherwise use a 949-step gray ramp from 0 to 1 (R == G == B).
        **kwargs: Forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    M = ColorTransformerModel.load_from_checkpoint(ckpt)
    # load_from_checkpoint does not switch the module to eval mode; do it
    # here so stochastic layers (if the model has any) do not perturb the
    # ordering that gets plotted.
    M.eval()

    if color:
        rgb_tensor, _ = extract_colors()
    else:
        n_grays = 949
        rgb_tensor = torch.linspace(0, 1, n_grays).unsqueeze(1).repeat(1, 3)

    # Keep the raw values for drawing; the model sees the preprocessed tensor.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    with torch.no_grad():
        # Move the input to wherever the checkpoint was loaded, mirroring
        # create_circle, so a GPU-resident model does not fail on a CPU tensor.
        preds = M(preprocess_data(rgb_tensor).to(M.device))
    sorted_inds = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    for position, idx in enumerate(sorted_inds):
        x = position + 0.5
        ax.plot([x, x], [0, 1], lw=1, c=rgb_values[idx], antialiased=True, alpha=1)
    ax.axis("off")

    fig.savefig(f"{fname}.png", **kwargs)
    # Release the figure; pyplot otherwise keeps it alive across calls.
    plt.close(fig)
def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Plot the model's ordering of the ``extract_colors()`` palette as a ring.

    Args:
        ckpt: Either a checkpoint path (loaded with all storages kept on CPU)
            or an already constructed ``ColorTransformerModel``.
        fname: Output path without extension, forwarded to ``plot_preds``.
        skip: Unused; kept so existing callers that pass it keep working.
        **kwargs: Forwarded to ``plot_preds`` (``roll``, ``dpi``, ``figsize``).
    """
    if isinstance(ckpt, str):
        # Returning the storage unchanged keeps every tensor on CPU, so a
        # checkpoint saved on a GPU machine still loads without one.
        M = ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
        # load_from_checkpoint does not switch the module to eval mode.
        M.eval()
    else:
        # Caller-owned model: its train/eval mode is deliberately left alone.
        M = ckpt

    xkcd_colors, _ = extract_colors()
    xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
    # Predictions are only plotted, so no autograd graph is needed.
    with torch.no_grad():
        preds = M(xkcd_colors)
    rgb_array = xkcd_colors.detach().cpu().numpy()
    plot_preds(preds, rgb_array, fname=fname, **kwargs)
10 months ago
def plot_preds(
    preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
):
    """Draw colors as equal wedges of a ring, ordered by ``preds``, and save a PNG.

    Args:
        preds: One prediction per color, as a ``torch.Tensor`` or array-like.
            It is flattened before sorting.
        rgb_values: 2-D array-like indexed ``[color, channel]``; only the
            first three columns are used as the wedge color.
        fname: Output path without extension; ``.png`` is appended.
        roll: If true, rotate the ordering so the first row equal to
            ``(1, 1, 1)`` becomes the first wedge. When no such row exists a
            message is printed and the ordering is left as is.
        dpi: Resolution passed to ``savefig``.
        figsize: Figure size passed to ``plt.subplots``.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    sorted_inds = np.argsort(np.asarray(preds).ravel())
    colors = np.asarray(rgb_values)[sorted_inds, :3]

    if roll:
        white = np.array([1, 1, 1])
        # flatnonzero yields a plain index array, so emptiness can be tested
        # directly; np.where would return a 1-tuple that is always truthy.
        white_rows = np.flatnonzero((colors == white).all(axis=1))
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    N = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # Wedge i starts at theta[i] (align="edge"); the pi/2 offset shifts the
    # first wedge a quarter turn from the polar axis' zero angle.
    theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
    width = 2 * np.pi / N  # equal angular size for each wedge
    for angle, rgb in zip(theta, colors):
        ax.bar(
            angle,
            height=1,
            width=width,
            # Edge drawn in the wedge's own color.
            edgecolor=rgb,
            linewidth=0.25,
            facecolor=rgb,
            bottom=0.0,
            zorder=1,
            alpha=1,
            align="edge",
        )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    radius = 1
    ax.set_ylim(-radius, radius)

    # Overlay a white disc on the centre of the ring.
    # NOTE: ``transData._b`` is a private matplotlib attribute.
    inner_radius = 1 / 3
    circle = patches.Circle(
        (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    fig.tight_layout(pad=0)
    fig.savefig(
        f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
    )
    plt.close(fig)
11 months ago
if __name__ == "__main__":
    # Command-line entry point: for each requested Lightning log version,
    # find a checkpoint and render its color ring to "<studio>/v<version>.png".
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    parser.add_argument(
        "--studio",
        default="colors-refactor-supervised",
        help=(
            "Studio whose lightning_logs are read; also the output directory "
            "(e.g. colors-refactor-unsupervised, "
            "colors-refactor-unsupervised-anchors, this_studio)."
        ),
    )
    args = parser.parse_args()

    studio = args.studio
    Path(studio).mkdir(exist_ok=True, parents=True)

    for v in args.version:
        name = f"{studio}/v{v}"
        ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
        candidates = glob.glob(ckpt_path)
        if not candidates:
            print(f"No checkpoint found for version {v}")
            continue

        # glob returns matches in arbitrary order, so pick the most recently
        # written checkpoint explicitly instead of relying on list position.
        ckpt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
        print(f"Generating image for checkpoint: {ckpt}")
        create_circle(
            ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False
        )