|
|
@ -6,6 +6,9 @@ import torch |
|
|
|
from dataloader import extract_colors, preprocess_data |
|
|
|
from model import ColorTransformerModel |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
# import matplotlib.colors as mcolors |
|
|
|
|
|
|
|
def make_image(ckpt: str, fname: str, color=True): |
|
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
@ -36,15 +39,53 @@ def make_image(ckpt: str, fname: str, color=True): |
|
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
|
|
|
|
|
|
|
|
|
def create_circle(ckpt: str, fname: str): |
|
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
|
rgb_tensor, names = extract_colors() |
|
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
preds = M(rgb_tensor) |
|
|
|
sorted_inds = np.argsort(preds.detach().numpy().ravel()) |
|
|
|
colors = rgb_values[sorted_inds] |
|
|
|
# find white in colors, put it first. |
|
|
|
white = np.array([1, 1, 1]) |
|
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|
|
colors = np.roll(colors, -white_idx, axis=0) |
|
|
|
# print(white_idx, colors[:2]) |
|
|
|
|
|
|
|
N = len(colors) |
|
|
|
# Create a plot with these hues in a circle |
|
|
|
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) |
|
|
|
|
|
|
|
# Each wedge in the circle |
|
|
|
theta = np.linspace(0, 2 * np.pi, N+1) + np.pi / 2 |
|
|
|
width = 2 * np.pi / (N) # equal size for each wedge |
|
|
|
|
|
|
|
for i in range(N): |
|
|
|
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0) |
|
|
|
|
|
|
|
ax.set_xticks([]) |
|
|
|
ax.set_yticks([]) |
|
|
|
ax.axis("off") |
|
|
|
fig.tight_layout() |
|
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
# name = "color_128_0.3_1.00e-06" |
|
|
|
import argparse |
|
|
|
import glob |
|
|
|
|
|
|
|
v = 29 |
|
|
|
name = f"v{v}" |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
# make the following accept a list of arguments |
|
|
|
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) |
|
|
|
args = parser.parse_args() |
|
|
|
versions = args.version |
|
|
|
for v in versions: |
|
|
|
name = f"out/v{v}" |
|
|
|
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
|
|
|
ckpt = glob.glob( |
|
|
|
f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" |
|
|
|
)[-1] |
|
|
|
make_image(ckpt, fname=name) |
|
|
|
make_image(ckpt, fname=name + "b", color=False) |
|
|
|
ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" |
|
|
|
ckpt = glob.glob(ckpt_path)[-1] |
|
|
|
print(f"Generating image for checkpoint: {ckpt}") |
|
|
|
create_circle(ckpt, fname=name) |
|
|
|
# make_image(ckpt, fname=name + "b", color=False) |
|
|
|