"""Plot colors ordered by a trained ColorTransformerModel's predictions.

Loads a checkpoint, scores every color returned by ``extract_colors``, sorts
the colors by that score, and saves a strip of vertical lines (one per color,
left to right in ascending-score order) to ``<name>.png``.
"""

import matplotlib.pyplot as plt
import numpy as np
import torch

from dataloader import extract_colors
from model import ColorTransformerModel

RUN_NAME = "color_128_0.3_1.00e-06"
CKPT_PATH = f"/teamspace/jobs/{RUN_NAME}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"


def plot_color_ordering(name, ckpt, spacing=4, dpi=300):
    """Render colors sorted by the model's prediction and save ``{name}.png``.

    Args:
        name: Run name; used as the output file stem.
        ckpt: Checkpoint path passed to
            ``ColorTransformerModel.load_from_checkpoint``.
        spacing: Horizontal distance between consecutive lines (data units).
        dpi: Resolution passed to ``savefig``.
    """
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    # Inference only: switch train-only layers (e.g. dropout, if the model has
    # any) to eval behavior and skip building the autograd graph.
    model.eval()
    rgb_tensor, names = extract_colors()
    with torch.no_grad():
        preds = model(rgb_tensor)

    # Flatten so argsort yields a single ordering; assumes one scalar score
    # per color -- TODO confirm against the model's output shape.
    order = np.argsort(preds.cpu().numpy().ravel())

    fig, ax = plt.subplots(figsize=(10, 5))
    # Line colors come from `names`, so each entry must be a color
    # specification matplotlib recognizes. NOTE(review): rgb_tensor could be
    # used instead if the names are not valid matplotlib colors.
    ax.vlines(
        spacing * np.arange(len(order)),
        ymin=0,
        ymax=1,
        lw=1,
        colors=[names[idx] for idx in order],
    )
    ax.axis("off")
    fig.savefig(f"{name}.png", dpi=dpi)
    # Release the figure's memory once it has been written to disk.
    plt.close(fig)


if __name__ == "__main__":
    plot_color_ordering(RUN_NAME, CKPT_PATH)