diff --git a/check.py b/check.py index 3559327..1fb976d 100644 --- a/check.py +++ b/check.py @@ -6,6 +6,9 @@ import torch from dataloader import extract_colors, preprocess_data from model import ColorTransformerModel +import numpy as np +import matplotlib.pyplot as plt +# import matplotlib.colors as mcolors def make_image(ckpt: str, fname: str, color=True): M = ColorTransformerModel.load_from_checkpoint(ckpt) @@ -36,15 +39,53 @@ def make_image(ckpt: str, fname: str, color=True): plt.savefig(f"{fname}.png", dpi=300) +def create_circle(ckpt: str, fname: str): + M = ColorTransformerModel.load_from_checkpoint(ckpt) + rgb_tensor, names = extract_colors() + rgb_values = rgb_tensor.detach().numpy() + rgb_tensor = preprocess_data(rgb_tensor) + preds = M(rgb_tensor) + sorted_inds = np.argsort(preds.detach().numpy().ravel()) + colors = rgb_values[sorted_inds] + # find white in colors, put it first. + white = np.array([1, 1, 1]) + white_idx = np.where((colors == white).all(axis=1))[0][0] + colors = np.roll(colors, -white_idx, axis=0) + # print(white_idx, colors[:2]) + + N = len(colors) + # Create a plot with these hues in a circle + fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) + + # Each wedge in the circle + theta = np.linspace(0, 2 * np.pi, N+1) + np.pi / 2 + width = 2 * np.pi / (N) # equal size for each wedge + + for i in range(N): + ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0) + + ax.set_xticks([]) + ax.set_yticks([]) + ax.axis("off") + fig.tight_layout() + plt.savefig(f"{fname}.png", dpi=300) + + if __name__ == "__main__": # name = "color_128_0.3_1.00e-06" + import argparse import glob - v = 29 - name = f"v{v}" - # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" - ckpt = glob.glob( - f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" - )[-1] - make_image(ckpt, fname=name) - make_image(ckpt, fname=name + "b", color=False) + parser = argparse.ArgumentParser() + # make the following accept a list of arguments + parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) + args = parser.parse_args() + versions = args.version + for v in versions: + name = f"out/v{v}" + # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" + ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" + ckpt = glob.glob(ckpt_path)[-1] + print(f"Generating image for checkpoint: {ckpt}") + create_circle(ckpt, fname=name) + # make_image(ckpt, fname=name + "b", color=False)