You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

41 lines
1.3 KiB

import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
def make_image(ckpt: str, fname: str, color=True):
    """Render a strip of colors ordered by a trained model's output and save it.

    Loads a ``ColorTransformerModel`` from ``ckpt`` and runs it over a set of
    input colors. It then sorts the colors by the flattened model output and
    draws each color as a thin vertical line, left to right in sorted order.
    The figure is written as a 300-dpi PNG.

    Args:
        ckpt: Checkpoint path passed to
            ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output file name. ``".png"`` is appended unless ``fname``
            already ends with it.
        color: If True, use the colors returned by ``extract_colors()``.
            If False, use a ramp of 949 gray levels, where each row is
            ``(v, v, v)`` for ``v`` evenly spaced in [0, 1].
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    # Inference only: switch off train-time behavior (e.g. dropout), in case
    # the model has any, so the ordering is deterministic.
    model.eval()

    if color:
        rgb_tensor, _names = extract_colors()
    else:
        # NOTE(review): 949 presumably matches the number of colors returned
        # by extract_colors() so both modes give equally wide strips -- TODO confirm.
        n_gray = 949
        rgb_tensor = torch.linspace(0, 1, n_gray).unsqueeze(1).repeat(1, 3)

    # Keep the raw RGB values for drawing; only the model sees preprocessed data.
    rgb_values = rgb_tensor.detach().numpy()
    inputs = preprocess_data(rgb_tensor)

    # Ranking colors needs no autograd graph.
    with torch.no_grad():
        preds = model(inputs)
    order = np.argsort(preds.cpu().numpy().ravel())

    fig, ax = plt.subplots(figsize=(20, 5))
    try:
        for position, idx in enumerate(order):
            ax.plot(
                [position, position],
                [0, 5],
                lw=0.5,
                c=rgb_values[idx],
                antialiased=False,
                alpha=1,
            )
        ax.axis("off")
        out_path = fname if fname.endswith(".png") else f"{fname}.png"
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
if __name__ == "__main__":
    # Output base name. make_image adds the ".png" extension itself, so pass
    # the name without it.
    name = "color_64_1_1.0e-3"
    # Alternative run kept for reference:
    # name = "color_128_0.3_1.00e-06"
    # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
    ckpt = "/teamspace/studios/this_studio/colors/lightning_logs/version_26/checkpoints/epoch=99-step=1500.ckpt"
    make_image(ckpt, fname=name)