diff --git a/scripts/sortcolor.py b/scripts/sortcolor.py index 6d48c76..9ebf675 100644 --- a/scripts/sortcolor.py +++ b/scripts/sortcolor.py @@ -85,23 +85,36 @@ if KIND in ("lex", "alpha", "abc"): preds = np.array(colors) elif KIND == "umap": - # from umap import UMAP - from cuml import UMAP - - # Use UMAP to create a 1D representation - reducer = UMAP( - n_components=1, - n_neighbors=250, - min_dist=1e-2, - metric="euclidean", - random_state=SEED, - negative_sample_rate=2, - ) - embedding = reducer.fit_transform(np.array(rgb_values)) + PDIR = f"scripts/{KIND}-prod" + Path(PDIR).mkdir(parents=True, exist_ok=True) + file_path = f"{PDIR}/{SEED:06d}.npy" + + if Path(file_path).exists(): + print(f"Loading {file_path}") + preds = np.load(file_path) + + else: + # from umap import UMAP + from cuml import UMAP # not fully deterministic. + + # Use UMAP to create a 1D representation + reducer = UMAP( + n_components=1, + n_neighbors=250, + min_dist=1e-2, + metric="euclidean", + random_state=SEED, + negative_sample_rate=2, + ) + embedding = reducer.fit_transform(np.array(rgb_values)) + + # Sort colors by the 1D representation + preds = embedding[:, 0] + del reducer, embedding - # Sort colors by the 1D representation - preds = embedding[:, 0] - del reducer, embedding + # Save the sorted indices to disk + print(f"Saving {file_path}") + np.save(file_path, preds.ravel()) elif KIND in ("cielab", "lab", "ciede2000"): from skimage.color import deltaE_ciede2000, rgb2lab @@ -135,17 +148,9 @@ elif KIND == "hsv": else: raise ValueError(f"Unknown kind: {KIND}") -sorted_indices = np.argsort(preds) - -# Save the sorted indices to disk -# if (KIND == "umap") or (KIND != "umap"): -PDIR = f"scripts/{KIND}" -Path(PDIR).mkdir(parents=True, exist_ok=True) -file_path = f"{PDIR}/{SEED:06d}.npy" -np.save(file_path, preds.ravel()) -print(f"Predictions saved to {file_path}") # Sort colors by the 1D representation +sorted_indices = np.argsort(preds) sorted_colors = [colors[i] for i in sorted_indices] # # Display the sorted colors around the ring of a circle @@ -219,8 +224,6 @@ def plot_preds( dpi: int = 300, figsize=(6, 6), ): - if isinstance(preds, torch.Tensor): - preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds, :3] if roll: