You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

130 lines
4.0 KiB

# import matplotlib.patches as patches
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
# import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True):
    """Render model predictions as a strip of vertical lines sorted by prediction.

    Loads a ``ColorTransformerModel`` from ``ckpt`` and runs it on a set of RGB
    inputs. The inputs are sorted by the model's flattened prediction, and one
    vertical line per input is drawn in that order, colored with the input's
    (un-preprocessed) RGB value. The figure is saved to ``f"{fname}.png"`` at
    300 dpi.

    Args:
        ckpt: Path to a checkpoint accepted by
            ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path without the ``.png`` extension.
        color: If True, use the colors returned by ``extract_colors()``.
            If False, use a 949-level grayscale ramp from (0, 0, 0) to (1, 1, 1).
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)

    if color:
        rgb_tensor, _names = extract_colors()
    else:
        # Grayscale ramp: evenly spaced levels in [0, 1], replicated over R, G, B.
        n_levels = 949
        levels = torch.linspace(0, 1, n_levels)
        rgb_tensor = levels.unsqueeze(1).repeat(1, 3)

    # Keep the raw values for drawing; the model sees the preprocessed ones.
    rgb_values = rgb_tensor.detach().numpy()
    model_input = preprocess_data(rgb_tensor).to(model.device)

    # Inference only: the predictions are just sorted, so skip the autograd graph.
    with torch.no_grad():
        preds = model(model_input)
    sorted_inds = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    try:
        for position, idx in enumerate(sorted_inds):
            # One thin vertical line per input, centered in its unit-wide slot.
            ax.plot(
                [position + 0.5, position + 0.5],
                [0, 1],
                lw=1,
                c=rgb_values[idx],
                antialiased=True,
                alpha=1,
            )
        ax.axis("off")
        fig.savefig(f"{fname}.png", dpi=300)
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
def create_circle(ckpt: str, fname: str, dpi: int = 150):
    """Run a model on the ``extract_colors()`` palette and save a color-ring image.

    Args:
        ckpt: Either a checkpoint path (``str`` or ``os.PathLike``), which is
            loaded with ``ColorTransformerModel.load_from_checkpoint``, or an
            already-constructed model, which is used as-is. The ``str``
            annotation is kept for compatibility.
        fname: Output path without the ``.png`` extension.
        dpi: Resolution of the saved image.
    """
    import os

    if isinstance(ckpt, (str, os.PathLike)):
        model = ColorTransformerModel.load_from_checkpoint(ckpt)
    else:
        model = ckpt

    rgb_tensor, _ = extract_colors()
    # NOTE(review): unlike make_image, the palette is not passed through
    # preprocess_data before inference here -- confirm this is intended.
    # Inference only: the predictions are just sorted and plotted, so no
    # autograd graph is needed. The model's train/eval mode is left untouched
    # because a caller may pass a live model.
    with torch.no_grad():
        preds = model(rgb_tensor.to(model.device))
    plot_preds(preds, fname=fname, dpi=dpi)
def plot_preds(
    preds,
    fname: str,
    roll: bool = False,
    dpi: int = 150,
    figsize=(3, 3),
    inner_radius: float = 1 / 3,
):
    """Draw the ``extract_colors()`` palette as a ring ordered by ``preds``.

    Each palette color becomes one equal-width wedge of a polar bar chart.
    Wedges are ordered by ascending prediction and placed at increasing angles
    starting from theta = pi/2. A white disk is drawn over the center so the
    result is a ring. The figure is saved to ``f"{fname}.png"``.

    Args:
        preds: One scalar prediction per palette color. Accepts a
            ``torch.Tensor`` (any device) or anything ``np.asarray`` accepts.
            It is flattened before sorting.
        fname: Output path without the ``.png`` extension.
        roll: If True, rotate the ordering so the color closest to white comes
            first. When exact white is present, that is the first exact match.
        dpi: Resolution of the saved image.
        figsize: Matplotlib figure size in inches.
        inner_radius: Radius of the white center disk (default 1/3, as before).

    Raises:
        ValueError: If the number of predictions differs from the number of
            palette colors.
    """
    rgb_tensor, _ = extract_colors()
    rgb_values = rgb_tensor.detach().numpy()

    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    scores = np.asarray(preds).ravel()
    if scores.size != len(rgb_values):
        raise ValueError(
            f"expected {len(rgb_values)} predictions (one per color), "
            f"got {scores.size}"
        )

    colors = rgb_values[np.argsort(scores)]
    if roll:
        # Put the color nearest to white first. For exact white this equals the
        # first exact match, and it no longer raises when white is absent.
        white = np.array([1, 1, 1])
        white_idx = int(np.argmin(np.linalg.norm(colors - white, axis=1)))
        colors = np.roll(colors, -white_idx, axis=0)

    n_colors = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
    try:
        # One equal-width wedge per color, starting at the top (pi/2).
        theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
        width = 2 * np.pi / n_colors
        # A single vectorized call; `color` takes one face color per bar.
        ax.bar(
            theta,
            1,
            width=width,
            bottom=0.0,
            color=colors,
            edgecolor="none",
            zorder=1,
            alpha=1,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis("off")
        ax.set_aspect("equal")

        # White center disk. NOTE: `transData._b` is a private matplotlib
        # attribute: the part of the polar data transform after the polar
        # projection. It may break in future matplotlib versions.
        circle = patches.Circle(
            (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
        )
        ax.add_patch(circle)

        fig.tight_layout(pad=0)
        fig.savefig(f"{fname}.png", dpi=dpi, transparent=False)
    finally:
        plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: render one color-ring image per Lightning log version.
    import argparse
    import glob
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--version",
        type=int,
        nargs="+",
        default=[0, 1],
        help="One or more lightning_logs version numbers to render.",
    )
    parser.add_argument(
        "--dpi", type=int, default=150, help="Resolution for saved image."
    )
    parser.add_argument(
        "--ckpt-root",
        default="/teamspace/studios/this_studio/colors/lightning_logs",
        help="Directory containing version_<N>/checkpoints/*.ckpt.",
    )
    parser.add_argument(
        "--out-dir", default="out", help="Directory for the generated images."
    )
    args = parser.parse_args()

    # savefig does not create missing directories.
    os.makedirs(args.out_dir, exist_ok=True)

    for v in args.version:
        name = os.path.join(args.out_dir, f"v{v}")
        ckpt_pattern = os.path.join(
            args.ckpt_root, f"version_{v}", "checkpoints", "*.ckpt"
        )
        candidates = glob.glob(ckpt_pattern)
        if not candidates:
            print(f"No checkpoint found for version {v}")
            continue
        # glob returns files in arbitrary order; take the most recently written one.
        ckpt = max(candidates, key=os.path.getmtime)
        print(f"Generating image for checkpoint: {ckpt}")
        create_circle(ckpt, fname=name, dpi=args.dpi)
        # make_image(ckpt, fname=name + "b", color=False)