import matplotlib.pyplot as plt
import numpy as np
import torch

from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel


def make_image(ckpt: str, fname: str, color: bool = True) -> None:
    """Render colors as vertical stripes ordered by the model's prediction.

    A ``ColorTransformerModel`` is loaded from ``ckpt``, evaluated on a set
    of RGB rows, and one vertical line per row is drawn, left to right, in
    ascending order of the (flattened) model output. The result is saved
    as a PNG.

    Args:
        ckpt: Path to the checkpoint passed to
            ``ColorTransformerModel.load_from_checkpoint``.
        fname: Output path. ``.png`` is appended unless the name already
            ends with it, so both ``"out"`` and ``"out.png"`` write
            ``out.png``.
        color: If true, use the colors returned by ``extract_colors()``;
            otherwise use a gray ramp of evenly spaced values in [0, 1]
            repeated across the three channels.
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    # Inference only: disable dropout / batch-norm updates so the ordering
    # is deterministic for a given checkpoint.
    model.eval()

    if not color:
        # NOTE(review): 949 presumably mirrors the size of the named-color
        # set returned by extract_colors() -- confirm.
        n_samples = 949
        ramp = torch.linspace(0, 1, n_samples)
        rgb_tensor = ramp.unsqueeze(1).repeat(1, 3)
    else:
        rgb_tensor, _names = extract_colors()

    # Keep the raw RGB rows for drawing; the model sees the preprocessed
    # version of the same rows.
    rgb_values = rgb_tensor.detach().numpy()
    model_input = preprocess_data(rgb_tensor)

    with torch.no_grad():
        preds = model(model_input)
    sorted_inds = np.argsort(preds.detach().numpy().ravel())

    fig, ax = plt.subplots(figsize=(20, 5))
    try:
        for x_pos, idx in enumerate(sorted_inds):
            ax.plot(
                [x_pos, x_pos],
                [0, 5],
                lw=0.5,
                c=rgb_values[idx],
                antialiased=False,
                alpha=1,
            )
        ax.axis("off")

        # Avoid producing "name.png.png" when the caller already supplied
        # the extension.
        out_path = fname if fname.lower().endswith(".png") else f"{fname}.png"
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures inside pyplot.
        plt.close(fig)


if __name__ == "__main__":
    # Alternate run kept for reference:
    # name = "color_128_0.3_1.00e-06"
    name = "color_64_1_1.0e-3.png"
    # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
    ckpt = "/teamspace/studios/this_studio/colors/lightning_logs/version_26/checkpoints/epoch=99-step=1500.ckpt"
    make_image(ckpt, fname=name)