"""Plotting utilities that visualize how a ColorTransformerModel orders colors."""

from pathlib import Path
from typing import Union

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch

from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel


def _load_model(ckpt: Union[str, Path, ColorTransformerModel]) -> ColorTransformerModel:
    """Return a model: load it from a checkpoint path, or pass an instance through.

    Paths are loaded with ``map_location=lambda storage, loc: storage``, which
    keeps every tensor on the CPU storage it was deserialized to, so checkpoints
    saved on a GPU machine can be opened anywhere.
    """
    if isinstance(ckpt, (str, Path)):
        return ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
    return ckpt


def make_image(
    ckpt: Union[str, ColorTransformerModel], fname: str, color=True, **kwargs
):
    """Draw one vertical stripe per color, ordered by the model's prediction.

    Args:
        ckpt: Checkpoint path, or an already-loaded model.
        fname: Output path without extension; ``.png`` is appended.
        color: If True, use the colors returned by ``extract_colors()``. If
            False, use a 949-step gray ramp (R = G = B) from 0 to 1.
        **kwargs: Forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    model = _load_model(ckpt)

    if color:
        rgb_tensor, _ = extract_colors()
    else:
        n_gray = 949
        # Repeat each intensity across the three channels -> a gray level.
        rgb_tensor = torch.linspace(0, 1, n_gray).unsqueeze(1).repeat(1, 3)

    # The raw RGB values are what gets drawn; the model sees the preprocessed ones.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    inputs = preprocess_data(rgb_tensor).to(model.device)
    with torch.no_grad():
        preds = model(inputs)
    sorted_inds = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    for position, idx in enumerate(sorted_inds):
        ax.plot(
            [position + 0.5, position + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")
    fig.savefig(f"{fname}.png", **kwargs)
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate


def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Predict a value for every color from ``extract_colors()`` and plot a color wheel.

    Args:
        ckpt: Checkpoint path, or an already-loaded model. A passed-in model is
            used as-is; its train/eval mode is not changed.
        fname: Output path without extension; ``.png`` is appended.
        skip: Unused; kept for backward compatibility.
        **kwargs: Forwarded to ``plot_preds`` (``roll``, ``dpi``, ``figsize``).
    """
    model = _load_model(ckpt)
    xkcd_colors, _ = extract_colors()
    xkcd_colors = preprocess_data(xkcd_colors).to(model.device)
    with torch.no_grad():
        preds = model(xkcd_colors)
    # NOTE(review): the wheel is painted with the *preprocessed* colors, whereas
    # make_image paints the raw ones -- confirm preprocess_data keeps values in
    # a range matplotlib accepts as RGB.
    rgb_array = xkcd_colors.detach().cpu().numpy()
    plot_preds(preds, rgb_array, fname=fname, **kwargs)


def plot_preds(
    preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
):
    """Draw the colors as equal wedges of a polar ring, ordered by ``preds``.

    Args:
        preds: One scalar per color (tensor or array); flattened before sorting.
        rgb_values: Array with one row per color; the first three columns are
            used as the RGB fill.
        fname: Output path without extension; ``.png`` is appended.
        roll: If True, rotate the ordering so that pure white ``(1, 1, 1)`` is
            drawn first. If no such row exists, a message is printed and the
            ordering is left unchanged.
        dpi: Resolution of the saved PNG.
        figsize: Matplotlib figure size in inches.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]

    if roll:
        white = np.array([1, 1, 1])
        # np.flatnonzero gives an index array; test its size instead of the
        # (always truthy) tuple returned by np.where.
        white_rows = np.flatnonzero((colors == white).all(axis=1))
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    n_colors = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # Equal-width wedges; the pi/2 offset starts the sequence at the top of a
    # default (east-zero, counter-clockwise) polar axis.
    theta = np.linspace(0, 2 * np.pi, n_colors, endpoint=False) + np.pi / 2
    width = 2 * np.pi / n_colors
    # One call draws every wedge; `color`/`edgecolor` sequences are applied
    # per bar, in order.
    ax.bar(
        theta,
        height=1,
        width=width,
        bottom=0.0,
        color=colors,
        edgecolor=colors,
        linewidth=0.25,
        zorder=1,
        alpha=1,
        align="edge",
    )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    radius = 1
    # A negative lower radial limit moves r=0 away from the centre of the plot.
    ax.set_ylim(-radius, radius)

    # White disc over the middle of the wheel.
    # NOTE(review): `transData._b` is a private matplotlib attribute and may
    # change between matplotlib versions.
    inner_radius = 1 / 3
    circle = patches.Circle(
        (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    fig.tight_layout(pad=0)
    fig.savefig(
        f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
    )
    plt.close(fig)


if __name__ == "__main__":
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    # Accepts one or more version numbers, e.g. `-v 0 3 7`.
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument("--figsize", type=int, default=6, help="Figure size")
    args = parser.parse_args()

    # Other studios used previously: "colors-refactor-unsupervised",
    # "colors-refactor-unsupervised-anchors", "this_studio".
    studio = "colors-refactor-supervised"
    Path(studio).mkdir(exist_ok=True, parents=True)

    for v in args.version:
        name = f"{studio}/v{v}"
        ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
        ckpts = glob.glob(ckpt_path)
        if not ckpts:
            print(f"No checkpoint found for version {v}")
            continue
        # glob order is filesystem-dependent; take the most recently written file.
        ckpt = max(ckpts, key=lambda p: Path(p).stat().st_mtime)
        print(f"Generating image for checkpoint: {ckpt}")
        create_circle(
            ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False
        )
        # To also render the grayscale strip:
        # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi)