|
@ -85,8 +85,17 @@ if KIND in ("lex", "alpha", "abc"): |
|
|
preds = np.array(colors) |
|
|
preds = np.array(colors) |
|
|
|
|
|
|
|
|
elif KIND == "umap": |
|
|
elif KIND == "umap": |
|
|
|
|
|
PDIR = f"scripts/{KIND}-prod" |
|
|
|
|
|
Path(PDIR).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
file_path = f"{PDIR}/{SEED:06d}.npy" |
|
|
|
|
|
|
|
|
|
|
|
if Path(file_path).exists(): |
|
|
|
|
|
print(f"Loading {file_path}") |
|
|
|
|
|
preds = np.load(file_path) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
# from umap import UMAP |
|
|
# from umap import UMAP |
|
|
from cuml import UMAP |
|
|
from cuml import UMAP # not fully deterministic. |
|
|
|
|
|
|
|
|
# Use UMAP to create a 1D representation |
|
|
# Use UMAP to create a 1D representation |
|
|
reducer = UMAP( |
|
|
reducer = UMAP( |
|
@ -103,6 +112,10 @@ elif KIND == "umap": |
|
|
preds = embedding[:, 0] |
|
|
preds = embedding[:, 0] |
|
|
del reducer, embedding |
|
|
del reducer, embedding |
|
|
|
|
|
|
|
|
|
|
|
# Save the sorted indices to disk |
|
|
|
|
|
print(f"Saving {file_path}") |
|
|
|
|
|
np.save(file_path, preds.ravel()) |
|
|
|
|
|
|
|
|
elif KIND in ("cielab", "lab", "ciede2000"): |
|
|
elif KIND in ("cielab", "lab", "ciede2000"): |
|
|
from skimage.color import deltaE_ciede2000, rgb2lab |
|
|
from skimage.color import deltaE_ciede2000, rgb2lab |
|
|
|
|
|
|
|
@ -135,17 +148,9 @@ elif KIND == "hsv": |
|
|
else: |
|
|
else: |
|
|
raise ValueError(f"Unknown kind: {KIND}") |
|
|
raise ValueError(f"Unknown kind: {KIND}") |
|
|
|
|
|
|
|
|
sorted_indices = np.argsort(preds) |
|
|
|
|
|
|
|
|
|
|
|
# Save the sorted indices to disk |
|
|
|
|
|
# if (KIND == "umap") or (KIND != "umap"): |
|
|
|
|
|
PDIR = f"scripts/{KIND}" |
|
|
|
|
|
Path(PDIR).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
file_path = f"{PDIR}/{SEED:06d}.npy" |
|
|
|
|
|
np.save(file_path, preds.ravel()) |
|
|
|
|
|
print(f"Predictions saved to {file_path}") |
|
|
|
|
|
|
|
|
|
|
|
# Sort colors by the 1D representation |
|
|
# Sort colors by the 1D representation |
|
|
|
|
|
sorted_indices = np.argsort(preds) |
|
|
sorted_colors = [colors[i] for i in sorted_indices] |
|
|
sorted_colors = [colors[i] for i in sorted_indices] |
|
|
|
|
|
|
|
|
# # Display the sorted colors around the ring of a circle |
|
|
# # Display the sorted colors around the ring of a circle |
|
@ -219,8 +224,6 @@ def plot_preds( |
|
|
dpi: int = 300, |
|
|
dpi: int = 300, |
|
|
figsize=(6, 6), |
|
|
figsize=(6, 6), |
|
|
): |
|
|
): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
|
|
|
preds = preds.detach().cpu().numpy() |
|
|
|
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
colors = rgb_values[sorted_inds, :3] |
|
|
colors = rgb_values[sorted_inds, :3] |
|
|
if roll: |
|
|
if roll: |
|
|