Browse Source

draw circles

new-sep-loss
Michael Pilosov 11 months ago
parent
commit
e28753bc5b
  1. 55
      check.py

55
check.py

@ -6,6 +6,9 @@ import torch
from dataloader import extract_colors, preprocess_data from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel from model import ColorTransformerModel
import numpy as np
import matplotlib.pyplot as plt
# import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True): def make_image(ckpt: str, fname: str, color=True):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
@ -36,15 +39,53 @@ def make_image(ckpt: str, fname: str, color=True):
plt.savefig(f"{fname}.png", dpi=300) plt.savefig(f"{fname}.png", dpi=300)
def create_circle(ckpt: str, fname: str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
rgb_tensor, names = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor)
preds = M(rgb_tensor)
sorted_inds = np.argsort(preds.detach().numpy().ravel())
colors = rgb_values[sorted_inds]
# find white in colors, put it first.
white = np.array([1, 1, 1])
white_idx = np.where((colors == white).all(axis=1))[0][0]
colors = np.roll(colors, -white_idx, axis=0)
# print(white_idx, colors[:2])
N = len(colors)
# Create a plot with these hues in a circle
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
# Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N+1) + np.pi / 2
width = 2 * np.pi / (N) # equal size for each wedge
for i in range(N):
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0)
ax.set_xticks([])
ax.set_yticks([])
ax.axis("off")
fig.tight_layout()
plt.savefig(f"{fname}.png", dpi=300)
if __name__ == "__main__": if __name__ == "__main__":
# name = "color_128_0.3_1.00e-06" # name = "color_128_0.3_1.00e-06"
import argparse
import glob import glob
v = 29 parser = argparse.ArgumentParser()
name = f"v{v}" # make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1])
args = parser.parse_args()
versions = args.version
for v in versions:
name = f"out/v{v}"
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
ckpt = glob.glob( ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" ckpt = glob.glob(ckpt_path)[-1]
)[-1] print(f"Generating image for checkpoint: {ckpt}")
make_image(ckpt, fname=name) create_circle(ckpt, fname=name)
make_image(ckpt, fname=name + "b", color=False) # make_image(ckpt, fname=name + "b", color=False)

Loading…
Cancel
Save