You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.6 KiB
50 lines
1.6 KiB
# import matplotlib.patches as patches
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
from dataloader import extract_colors, preprocess_data
|
|
from model import ColorTransformerModel
|
|
|
|
|
|
def make_image(ckpt: str, fname: str, color=True, *, n_gray: int = 949):
    """Render the model's 1-D ordering of a set of colors as a strip image.

    Loads a ``ColorTransformerModel`` from ``ckpt`` and runs it on a set of RGB
    inputs. The inputs are sorted by the model's flattened output, and one
    vertical line per input is drawn in that sorted order. The figure is saved
    to ``f"{fname}.png"`` at 300 dpi and then closed.

    Args:
        ckpt: Checkpoint path passed to ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path without the ``.png`` extension.
        color: If True, use the colors returned by ``extract_colors()``.
            If False, use an evenly spaced grayscale ramp whose rows are
            ``(v, v, v)`` for ``v`` in ``[0, 1]``.
        n_gray: Number of grayscale samples used when ``color`` is False.
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    # Inference only: switch to eval mode so training-only behavior
    # (e.g. dropout) is disabled.
    model.eval()

    if color:
        rgb_tensor, _names = extract_colors()
    else:
        # Shape (n_gray, 3): each row repeats one value from linspace(0, 1).
        rgb_tensor = torch.linspace(0, 1, n_gray).unsqueeze(1).repeat(1, 3)

    # Keep the raw RGB values for plotting; the model sees the preprocessed version.
    rgb_values = rgb_tensor.detach().numpy()
    model_input = preprocess_data(rgb_tensor)

    # No gradients are needed to rank the inputs.
    with torch.no_grad():
        preds = model(model_input)
    # ravel() flattens the predictions; presumably one scalar per input row — TODO confirm.
    sorted_inds = np.argsort(preds.detach().numpy().ravel())

    fig, ax = plt.subplots()
    for position, idx in enumerate(sorted_inds):
        # One vertical strip per input, centered at position + 0.5.
        ax.plot(
            [position + 0.5, position + 0.5],
            [0, 1],
            lw=1,
            c=rgb_values[idx],
            antialiased=True,
            alpha=1,
        )
    ax.axis("off")

    fig.savefig(f"{fname}.png", dpi=300)
    # Release the figure so repeated calls do not accumulate open figures in
    # pyplot's global figure manager.
    plt.close(fig)
|
|
|
|
|
|
def find_latest_checkpoint(pattern: str) -> str:
    """Return the most recently modified file matching the glob ``pattern``.

    ``glob.glob`` returns matches in arbitrary (filesystem) order, so taking the
    last element of its result does not reliably select the newest checkpoint.
    Choosing by modification time does.

    Raises:
        FileNotFoundError: If no file matches ``pattern``.
    """
    import glob
    import os

    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"no checkpoint matches {pattern!r}")
    return max(matches, key=os.path.getmtime)


if __name__ == "__main__":
    v = 29
    name = f"v{v}"
    ckpt = find_latest_checkpoint(
        f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
    )
    # Render the color ordering, then the grayscale-ramp ordering ("b" suffix).
    make_image(ckpt, fname=name)
    make_image(ckpt, fname=name + "b", color=False)
|
|
|