You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
752 B
26 lines
752 B
11 months ago
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
from dataloader import extract_colors
|
||
|
from model import ColorTransformerModel
|
||
|
|
||
|
# Load a trained ColorTransformerModel, score every color returned by
# extract_colors(), and save a strip of vertical lines (one per color,
# ordered by the model's score) to "<name>.png".

name = "color_128_0.3_1.00e-06"  # run name; also used for the output filename
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"

M = ColorTransformerModel.load_from_checkpoint(ckpt)
# Put the model in inference mode explicitly before predicting (harmless if it
# already is): otherwise any dropout / batch-norm layers would run in training
# mode and the resulting ordering would not be reproducible.
M.eval()

rgb_tensor, names = extract_colors()
preds = M(rgb_tensor)
# Flatten predictions to one score per color and get the ascending order.
# NOTE(review): assumes preds holds exactly one value per row of rgb_tensor
# (so indices line up with `names`) — confirm against the model definition.
sorted_inds = np.argsort(preds.detach().numpy().ravel())

fig, ax = plt.subplots(figsize=(10, 5))
# One unit-height vertical line per color, spaced 4 units apart along x, in
# sorted order. Each entry of `names` is passed straight to matplotlib as the
# line's color spec.
positions = 4 * np.arange(len(sorted_inds))
line_colors = [names[idx] for idx in sorted_inds]
ax.vlines(positions, ymin=0, ymax=1, lw=1, colors=line_colors)
ax.axis("off")

plt.savefig(f"{name}.png", dpi=300)
|