from pathlib import Path
from typing import Tuple, Union

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch

from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel


def _load_model(ckpt: Union[str, ColorTransformerModel]) -> ColorTransformerModel:
    """Return ``ckpt`` unchanged if it is a model; otherwise load it from that path.

    Every stored tensor is mapped onto CPU, so checkpoints saved on a GPU load
    on machines without one.
    """
    if isinstance(ckpt, str):
        return ColorTransformerModel.load_from_checkpoint(
            ckpt, map_location=lambda storage, loc: storage
        )
    return ckpt


def make_image(
    ckpt: Union[str, ColorTransformerModel], fname: str, color: bool = True, **kwargs
):
    """Render colors sorted by model prediction as a strip of vertical lines.

    Args:
        ckpt: Checkpoint path, or an already-loaded model.
        fname: Output path without extension; ``.png`` is appended.
        color: If True, use the colors returned by ``extract_colors()``;
            otherwise use 949 evenly spaced grays from black to white.
        **kwargs: Forwarded to ``savefig``.
    """
    M = _load_model(ckpt)
    if color:
        rgb_tensor, _ = extract_colors()
    else:
        n_grays = 949
        # Equal R, G and B values give grays: shape (n_grays, 3), black to white.
        rgb_tensor = torch.linspace(0, 1, n_grays).unsqueeze(1).repeat(1, 3)

    # Draw with the raw colors; only the model sees the preprocessed tensor.
    rgb_values = rgb_tensor.detach().cpu().numpy()
    with torch.no_grad():
        preds = M(preprocess_data(rgb_tensor).to(M.device))
    sorted_inds = np.argsort(preds.detach().cpu().numpy().ravel())

    fig, ax = plt.subplots()
    for pos, idx in enumerate(sorted_inds):
        # One unit-height vertical line per color, centered in slot ``pos``.
        ax.plot(
            [pos + 0.5, pos + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")
    fig.savefig(f"{fname}.png", **kwargs)
    plt.close(fig)


def create_circle(
    ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
    """Sort the ``extract_colors()`` palette by model prediction and save a color ring.

    Args:
        ckpt: Checkpoint path, or an already-loaded model.
        fname: Output path passed on to ``plot_preds``.
        skip: Unused; kept for backward compatibility.
        **kwargs: Forwarded to ``plot_preds`` (``roll``, ``dpi``, ``figsize``, ...).
    """
    M = _load_model(ckpt)
    xkcd_colors, _ = extract_colors()
    xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
    with torch.no_grad():
        preds = M(xkcd_colors)
    # NOTE(review): unlike make_image, the drawn colors come from the
    # *preprocessed* tensor; plot_preds keeps only its first 3 columns.
    # Confirm preprocess_data leaves those columns as RGB in [0, 1].
    rgb_array = xkcd_colors.detach().cpu().numpy()
    plot_preds(preds, rgb_array, fname=fname, **kwargs)


def plot_preds(
    preds: torch.Tensor | np.ndarray,
    rgb_values,
    fname: str,
    roll: bool = False,
    inner_radius: float = 1 / 3,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal wedges of a ring, ordered by ``preds``, and save it.

    Args:
        preds: One scalar per color (flattened with ``ravel``). Colors are drawn
            in ascending order of this value, counter-clockwise from the top.
        rgb_values: Array whose row ``i`` (first 3 columns) is the RGB color
            belonging to ``preds[i]``.
        fname: Output path for ``savefig``.
        roll: If True, rotate the order so the first pure-white ``(1, 1, 1)``
            color is drawn first; prints a message if there is none.
        inner_radius: Radius of the white disk covering the center; the ring's
            outer radius is 1.
        dpi: Resolution of the saved image.
        figsize: Figure size in inches.
        fsize: Font size of ``label`` at the center; 0 disables the label.
        label: Text drawn at the center when ``fsize > 0``.

    Raises:
        ValueError: If there are no colors to plot.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]

    if roll:
        # np.where returns a tuple (always truthy), so test the index array's size.
        white_rows = np.flatnonzero((colors == np.array([1, 1, 1])).all(axis=1))
        if white_rows.size:
            colors = np.roll(colors, -white_rows[0], axis=0)
        else:
            print("no white, skipping")

    n = len(colors)
    if n == 0:
        raise ValueError("no colors to plot")

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # n equal wedges starting at the top (pi/2). With align="edge", wedge i
    # spans [theta[i], theta[i] + width].
    theta = np.linspace(0, 2 * np.pi, n, endpoint=False) + np.pi / 2
    width = 2 * np.pi / n
    ax.bar(
        theta,
        height=1,
        width=width,
        bottom=0.0,
        color=colors,
        edgecolor=colors,  # same-colored edges hide hairline gaps between wedges
        linewidth=0.25,
        zorder=1,
        alpha=1,
        align="edge",
    )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_ylim(0, 1)

    # ``transData._b`` is a private matplotlib attribute (the transform stage
    # after the polar projection). Using it draws a Cartesian circle centered
    # on the pole instead of a patch in (theta, r) coordinates.
    circle = patches.Circle(
        (0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    if fsize > 0:
        ax.annotate(
            label, (0, 0), ha="center", va="center", size=fsize, color="black"
        )

    fig.tight_layout(pad=0)
    fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    import argparse
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        # Must be a list: with a bare string default, the loop below would
        # iterate the characters of "this_studio".
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument("--figsize", type=float, default=6, help="Figure size")
    args = parser.parse_args()

    for studio in args.studio:
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = (
                f"/teamspace/studios/{studio}/colors/lightning_logs/"
                f"version_{v}/checkpoints/*.ckpt"
            )
            # glob order is arbitrary; sort (lexicographically) so the pick is
            # deterministic.
            ckpts = sorted(glob.glob(ckpt_path))
            if not ckpts:
                print(f"No checkpoint found for version {v}")
                continue
            ckpt = ckpts[-1]
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )