import glob
import pickle
from multiprocessing import Process
from pathlib import Path
from typing import Optional, Tuple, Union

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch

from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel


def _predict(M, rgb_tensor: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``M`` on ``rgb_tensor`` without tracking gradients.

    The input is moved to ``M.device`` first. Returns the model output and the
    (device-moved) input, both as CPU NumPy arrays.
    """
    rgb_tensor = rgb_tensor.to(M.device)
    with torch.no_grad():
        preds = M(rgb_tensor)
    return preds.detach().cpu().numpy(), rgb_tensor.detach().cpu().numpy()


def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs):
    """Draw the model's color ordering as a strip of vertical lines.

    Colors are sorted by the model's prediction and drawn left to right, one
    vertical line per color. The figure is saved to ``f"{fname}.png"``.

    Args:
        ckpt: Path to a ``ColorTransformerModel`` checkpoint.
        fname: Output path without extension; ``.png`` is appended.
        color: If ``False``, score a 949-step grayscale ramp (not passed
            through ``preprocess_data``). Otherwise score the colors from
            ``extract_colors`` after ``preprocess_data``.
        **kwargs: Forwarded to ``Figure.savefig`` (e.g. ``dpi``).
    """
    M = ColorTransformerModel.load_from_checkpoint(ckpt)
    if color is False:
        # Grayscale ramp: n_gray evenly spaced levels replicated across R, G, B.
        n_gray = 949
        rgb_tensor = torch.linspace(0, 1, n_gray).unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _ = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor)
    # _predict moves the input to M.device for BOTH branches (the grayscale
    # branch previously stayed on the CPU regardless of the model's device).
    preds, rgb_values = _predict(M, rgb_tensor)

    sorted_inds = np.argsort(preds.ravel())
    fig, ax = plt.subplots()
    # One line per color at x = i + 0.5, drawn in a single vectorized call.
    ax.vlines(
        np.arange(len(sorted_inds)) + 0.5,
        0,
        1,
        colors=rgb_values[sorted_inds],
        linewidths=1,
        antialiased=True,
        alpha=1,
    )
    ax.axis("off")
    fig.savefig(f"{fname}.png", **kwargs)
    plt.close(fig)  # release the figure; this may be called in a loop


def do_inference(
    ckpt: Union[str, Path, ColorTransformerModel]
) -> Tuple[np.ndarray, np.ndarray]:
    """Score the ``extract_colors`` palette with a model.

    Args:
        ckpt: A checkpoint path (``str`` or ``Path``) or an already loaded
            ``ColorTransformerModel``.

    Returns:
        ``(preds, rgb_array)``: the model output and the preprocessed colors
        that were fed to it, both as NumPy arrays on the CPU.
    """
    if isinstance(ckpt, (str, Path)):
        # Keep every storage where torch deserializes it (CPU), so checkpoints
        # saved on a GPU can be loaded on a CPU-only machine.
        M = ColorTransformerModel.load_from_checkpoint(
            str(ckpt), map_location=lambda storage, loc: storage
        )
    else:
        M = ckpt
    xkcd_colors, _ = extract_colors()
    return _predict(M, preprocess_data(xkcd_colors))


def create_circle(
    ckpt: Union[str, Path, ColorTransformerModel], fname: str, **kwargs
):
    """Run inference and save the circular color-wheel plot to ``fname``.

    ``**kwargs`` are forwarded to :func:`plot_preds`.
    """
    preds, rgb_array = do_inference(ckpt)
    plot_preds(preds, rgb_array, fname=fname, **kwargs)


def _plot_preds_serialized(serialized_data: bytes, fname: str, **kwargs):
    """Child-process entry point: unpickle ``(preds, rgb_array)`` and plot.

    NOTE: ``pickle.loads`` must only see data produced by
    ``create_circle_nonblocking``; never feed it untrusted bytes.
    """
    preds, rgb_array = pickle.loads(serialized_data)
    plot_preds(preds, rgb_array, fname=fname, **kwargs)


def create_circle_nonblocking(
    ckpt: Union[str, Path, ColorTransformerModel], fname: str, **kwargs
) -> Process:
    """Run inference in this process, then plot in a separate process.

    Inference runs synchronously. Only the matplotlib rendering is moved to a
    child process. ``**kwargs`` are forwarded to :func:`plot_preds`.

    Returns:
        The started ``multiprocessing.Process``; the caller should ``join()`` it.
    """
    preds, rgb_array = do_inference(ckpt)
    serialized_data = pickle.dumps((preds, rgb_array))
    p = Process(
        target=_plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs
    )
    p.start()
    return p


def _roll_to_white(colors: np.ndarray) -> np.ndarray:
    """Rotate rows of ``colors`` so the first exact ``[1, 1, 1]`` row is first.

    If no row equals white exactly, prints a notice and returns ``colors``
    unchanged.
    """
    white_rows = np.flatnonzero((colors == np.array([1, 1, 1])).all(axis=1))
    if white_rows.size == 0:
        # Previously unreachable: np.where returns a non-empty tuple, so the
        # old truthiness check always passed and then raised IndexError.
        print("no white, skipping")
        return colors
    return np.roll(colors, -white_rows[0], axis=0)


def plot_preds(
    preds: np.ndarray,
    rgb_values: np.ndarray,
    fname: str,
    roll: bool = False,
    radius: float = 1 / 2,
    dpi: int = 300,
    figsize: Tuple[float, float] = (6, 6),
    fsize: int = 0,
    label: str = "",
):
    """Draw colors as equal-width wedges around a ring, ordered by ``preds``.

    Args:
        preds: Model outputs; flattened and used only for sort order.
        rgb_values: One row per prediction; the first three columns are used
            as the wedge color.
        fname: Output path passed to ``Figure.savefig``.
        roll: If True, rotate the ordering so the exact-white color is first.
        radius: Radius of the white disc drawn over the center (the outer
            edge is at 1).
        dpi: Output resolution.
        figsize: Figure ``(width, height)`` in inches.
        fsize: Font size of ``label``; no label is drawn when ``fsize <= 0``.
        label: Text drawn at the center of the ring.
    """
    sorted_inds = np.argsort(preds.ravel())
    colors = rgb_values[sorted_inds, :3]
    if roll:
        colors = _roll_to_white(colors)

    N = len(colors)
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # N equal wedges, the first one starting at 12 o'clock (pi / 2).
    theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
    width = 2 * np.pi / N
    ax.bar(
        theta,
        height=1,
        width=width,
        color=colors,
        edgecolor=colors,
        linewidth=0.25,
        bottom=0.0,
        zorder=1,
        alpha=1,
        align="edge",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_ylim(0, 1)  # wedges span r in [0, 1]: outer radius is 1

    # White disc over the center to turn the pie into a ring.
    # NOTE(review): ``transData._b`` is a private matplotlib attribute,
    # presumably the affine (post-projection) part of the polar transform, so
    # (0, 0) and ``radius`` are Cartesian data units. Confirm on upgrades.
    circle = patches.Circle(
        (0, 0), radius, transform=ax.transData._b, color="white", zorder=2
    )
    ax.add_patch(circle)

    if fsize > 0.0:
        ax.annotate(
            label,
            (0, 0),
            ha="center",
            va="center",
            size=fsize,
            color="black",
        )

    fig.tight_layout(pad=0)
    fig.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
    plt.close(fig)


def _latest_checkpoint(pattern: str) -> Optional[str]:
    """Return the most recently modified path matching ``pattern``, or None.

    ``glob.glob`` returns matches in arbitrary order, so ``[-1]`` is not
    reliably the latest; modification time is used instead.
    """
    matches = glob.glob(pattern)
    if not matches:
        return None
    return max(matches, key=lambda p: Path(p).stat().st_mtime)


def main():
    """CLI: render a color wheel for each requested studio/version checkpoint."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
    parser.add_argument(
        "--dpi", type=int, default=300, help="Resolution for saved image."
    )
    parser.add_argument(
        "--studio",
        type=str,
        default=["this_studio"],
        nargs="+",
        help="Checkpoint studio name.",
    )
    parser.add_argument(
        "--figsize", type=float, default=6, help="Figure width and height in inches."
    )
    args = parser.parse_args()

    for studio in args.studio:
        Path(studio).mkdir(exist_ok=True, parents=True)
        for v in args.version:
            name = f"{studio}/v{v}"
            ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
            # TODO: allow specification via CLI
            ckpt = _latest_checkpoint(ckpt_path)
            if ckpt is None:
                print(f"No checkpoint found for version {v}")
                continue
            print(f"Generating image for checkpoint: {ckpt}")
            create_circle(
                ckpt,
                fname=name,
                dpi=args.dpi,
                figsize=[args.figsize] * 2,
                roll=False,
            )


if __name__ == "__main__":
    main()