Browse Source

load umap

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
5948ab477e
  1. 57
      scripts/sortcolor.py

57
scripts/sortcolor.py

@ -85,23 +85,36 @@ if KIND in ("lex", "alpha", "abc"):
preds = np.array(colors)
elif KIND == "umap":
# from umap import UMAP
from cuml import UMAP
# Use UMAP to create a 1D representation
reducer = UMAP(
n_components=1,
n_neighbors=250,
min_dist=1e-2,
metric="euclidean",
random_state=SEED,
negative_sample_rate=2,
)
embedding = reducer.fit_transform(np.array(rgb_values))
PDIR = f"scripts/{KIND}-prod"
Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{SEED:06d}.npy"
if Path(file_path).exists():
print(f"Loading {file_path}")
preds = np.load(file_path)
else:
# from umap import UMAP
from cuml import UMAP # not fully deterministic.
# Use UMAP to create a 1D representation
reducer = UMAP(
n_components=1,
n_neighbors=250,
min_dist=1e-2,
metric="euclidean",
random_state=SEED,
negative_sample_rate=2,
)
embedding = reducer.fit_transform(np.array(rgb_values))
# Sort colors by the 1D representation
preds = embedding[:, 0]
del reducer, embedding
# Sort colors by the 1D representation
preds = embedding[:, 0]
del reducer, embedding
# Save the sorted indices to disk
print(f"Saving {file_path}")
np.save(file_path, preds.ravel())
elif KIND in ("cielab", "lab", "ciede2000"):
from skimage.color import deltaE_ciede2000, rgb2lab
@ -135,17 +148,9 @@ elif KIND == "hsv":
else:
raise ValueError(f"Unknown kind: {KIND}")
sorted_indices = np.argsort(preds)
# Save the sorted indices to disk
# if (KIND == "umap") or (KIND != "umap"):
PDIR = f"scripts/{KIND}"
Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{SEED:06d}.npy"
np.save(file_path, preds.ravel())
print(f"Predictions saved to {file_path}")
# Sort colors by the 1D representation
sorted_indices = np.argsort(preds)
sorted_colors = [colors[i] for i in sorted_indices]
# # Display the sorted colors around the ring of a circle
@ -219,8 +224,6 @@ def plot_preds(
dpi: int = 300,
figsize=(6, 6),
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3]
if roll:

Loading…
Cancel
Save