diff --git a/app/web/main.py b/app/web/main.py
index e04b25e..9e156da 100644
--- a/app/web/main.py
+++ b/app/web/main.py
@@ -450,13 +450,31 @@ def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
 # ---------------------------------------------------------------------------
 
 
-def embed_args_hash(embed_args: Optional[Dict[str, Any]]) -> str:
-    """8-hex digest of an embed_args dict (keys sorted). Stems incorporate
-    this so runs that differ only in embed_args get distinct output files."""
-    s = json.dumps(embed_args or {}, sort_keys=True, default=str)
+def run_args_hash(
+    embed_args: Optional[Dict[str, Any]],
+    generator_kwargs: Optional[Dict[str, Any]] = None,
+) -> str:
+    """8-hex digest of (embed_args, generator_kwargs). When generator_kwargs
+    is empty/None we hash embed_args alone — preserves stems for the plain
+    generators (s_curve, plain swiss_roll) that never had gen_kwargs. For
+    kwargs-bearing variants (swiss_roll_hole, blobs, gaussian_quantiles,
+    classification), the hash now disambiguates them from their kwargs-less
+    siblings — run scripts/backfill_hashes.py to rehash existing figs."""
+    if generator_kwargs:
+        payload: Any = {
+            "embed_args": embed_args or {},
+            "generator_kwargs": generator_kwargs,
+        }
+    else:
+        payload = embed_args or {}
+    s = json.dumps(payload, sort_keys=True, default=str)
     return hashlib.sha1(s.encode()).hexdigest()[:8]
 
 
+# Back-compat alias — some call sites passed only embed_args.
+embed_args_hash = run_args_hash
+
+
 def synthesize_output_paths(
     generator_path: str,
     embedder: str,
@@ -465,6 +483,7 @@ def synthesize_output_paths(
     jitter_scale: float,
     seed: int,
     embed_args: Optional[Dict[str, Any]] = None,
+    generator_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Tuple[str, str]:
     gen = generator_path.split(".")[-1]
     emb = embedder.split(".")[-1]
@@ -473,7 +492,7 @@ def synthesize_output_paths(
     if embed_args is None:
         embf = f"{base}.html"
     else:
-        embf = f"{base}_{embed_args_hash(embed_args)}.html"
+        embf = f"{base}_{run_args_hash(embed_args, generator_kwargs)}.html"
     return ref, embf
 
 
@@ -620,6 +639,7 @@ def _run_view(run: Dict[str, Any]) -> Dict[str, Any]:
         float(params.get("jitter_scale", 0.01)),
         int(params.get("seed", 42)),
         embed_args=params.get("embed_args") or {},
+        generator_kwargs=params.get("generator_kwargs") or {},
     )
     # Older runs may lack the hash suffix; prefer legacy name on disk.
     emb_file = _resolve_emb_file(emb_file)
@@ -788,13 +808,12 @@ async def submit(request: Request) -> HTMLResponse:
     embed_args = build_embed_args(reducer, data)
 
     # Reject submissions whose output path would overwrite an existing fig.
-    # The stem now includes an 8-hex hash of embed_args, so UMAP(n_neighbors=5)
-    # and UMAP(n_neighbors=15) produce distinct files. Check both the hashed
-    # path (new runs) and the legacy hashless path (pre-hash runs) so users
-    # can't accidentally duplicate against a pre-hash fig either.
+    # Hash now covers both embed_args and generator_kwargs, so swiss_roll vs
+    # swiss_roll_hole (and blobs with varying n_features, etc.) no longer
+    # share a stem. Also check the legacy hashless path for pre-hash figs.
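+    # Worked example (hypothetical values): with embed_args={"n_neighbors": 15},
+    # an empty generator_kwargs hashes {"n_neighbors": 15} alone, while
+    # generator_kwargs={"hole": True} hashes the combined
+    # {"embed_args": ..., "generator_kwargs": ...} payload, yielding a
+    # different 8-hex tag and therefore a different stem.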
     _, hashed_emb = synthesize_output_paths(
         generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
-        embed_args=embed_args,
+        embed_args=embed_args, generator_kwargs=generator_kwargs,
     )
     _, legacy_emb = synthesize_output_paths(
         generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
@@ -838,7 +857,7 @@ async def submit(request: Request) -> HTMLResponse:
     ref_file, emb_file = synthesize_output_paths(
         generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
-        embed_args=embed_args,
+        embed_args=embed_args, generator_kwargs=generator_kwargs,
     )
     RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file}
@@ -895,20 +914,61 @@
 for _m in DATASET_META.values():
     _GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m)
 
 
+def _lookup_dataset_meta(
+    generator_short: str, generator_kwargs: Optional[Dict[str, Any]]
+) -> Optional[Dict[str, Any]]:
+    """Match DATASET_META by generator short-name AND kwargs when available.
+    Falls back to first-wins when kwargs are unknown (ambiguous for
+    swiss_roll vs swiss_roll_hole — both share `make_swiss_roll`)."""
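+    # For example (hypothetical entries): swiss_roll and swiss_roll_hole may
+    # both resolve to make_swiss_roll and differ only in kwargs ({} vs,
+    # say, {"hole": True}); the kwargs equality check below tells them apart.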
+    candidates = [
+        m for m in DATASET_META.values()
+        if m["path"].rsplit(".", 1)[-1] == generator_short
+    ]
+    if not candidates:
+        return None
+    if generator_kwargs is not None:
+        for m in candidates:
+            if m["kwargs"] == generator_kwargs:
+                return m
+    return candidates[0]
+
+
 def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
     """Attach per-point class/continuous labels by regenerating the dataset
-    with the same (generator, n_samples, kwargs). The stem's `seed` drives
-    jitter — NOT generator — so we always use random_state=0 to match the
-    flow's _DEFAULT_GENERATOR_KWARGS. Jitter-added points (id >= num_points)
-    get None so the client renders them as black."""
-    meta = _GEN_TO_META.get(d["meta"].get("generator") or "")
-    if not meta:
+    with the same (generator, n_samples, kwargs). random_state is fixed at 0
+    (the flow's _DEFAULT_GENERATOR_KWARGS) — the stem's `seed` drives jitter,
+    not the generator. Jitter-added points (id >= num_points) get None so
+    the client renders them as black.
+
+    Discovers generator_kwargs in priority order: (1) payload meta (sidecar
+    runs from the updated flow); (2) sibling metrics.json; (3) DATASET_META
+    by first-match (ambiguous for swiss_roll/swiss_roll_hole — need a
+    backfilled metrics.json to disambiguate)."""
+    meta = d.get("meta") or {}
+    gen_short = meta.get("generator") or ""
+
+    gk = meta.get("generator_kwargs")
+    if gk is None:
+        stem = meta.get("stem")
+        if stem:
+            mx = FIGS_DIR / f"{stem}.metrics.json"
+            if mx.is_file():
+                try:
+                    gk = json.loads(mx.read_text(encoding="utf-8")).get(
+                        "meta", {}
+                    ).get("generator_kwargs")
+                except Exception:
+                    gk = None
+
+    dm = _lookup_dataset_meta(gen_short, gk)
+    if not dm:
         return d
+    kwargs_to_use = gk if gk is not None else dm["kwargs"]
     try:
-        mod_path, cls_name = meta["path"].rsplit(".", 1)
+        mod_path, cls_name = dm["path"].rsplit(".", 1)
         fn = getattr(importlib.import_module(mod_path), cls_name)
-        N = int(d["meta"]["num_points"])
-        _, gen_labels = fn(n_samples=N, random_state=0, **meta["kwargs"])
+        N = int(meta["num_points"])
+        _, gen_labels = fn(n_samples=N, random_state=0, **kwargs_to_use)
         out_labels: List[Optional[float]] = []
         for pid in d["point_ids"]:
             if isinstance(pid, int) and 0 <= pid < N:
@@ -917,7 +977,7 @@ def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
             else:
                 out_labels.append(None)
         d["labels"] = out_labels
-        d["label_kind"] = meta["kind"]
+        d["label_kind"] = dm["kind"]
     except Exception:
         pass
     return d
@@ -934,6 +994,10 @@ def _cached_frames(stem: str) -> str:
     else:
         html = FIGS_DIR / f"{stem}.html"
         d = parse_plotly_run(html)
+    # Override meta.stem with the URL-requested stem — after a backfill the
+    # file was renamed but the baked-in meta.stem still points at the old
+    # name. Enrichment uses this to find the sibling metrics.json.
+    d.setdefault("meta", {})["stem"] = stem
     d = _enrich_with_labels(d)
     return json.dumps(d, separators=(",", ":"))
diff --git a/flows/embedding_flow.py b/flows/embedding_flow.py
index 2d1b1e3..9ebe7a2 100644
--- a/flows/embedding_flow.py
+++ b/flows/embedding_flow.py
@@ -27,10 +27,19 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 
-def _embed_args_hash(ea: Optional[Dict[str, Any]]) -> str:
-    """8-hex digest of embed_args (keys sorted) — output stem includes this
-    so runs differing only in embed_args get distinct files."""
-    s = json.dumps(ea or {}, sort_keys=True, default=str)
+def _run_args_hash(
+    ea: Optional[Dict[str, Any]],
+    gk: Optional[Dict[str, Any]] = None,
+) -> str:
+    """8-hex digest over (embed_args, generator_kwargs). When gk is empty we
+    hash embed_args alone — keeps stems stable for plain generators that
+    never had gen_kwargs (s_curve, plain swiss_roll).
+    Must mirror app.web.main.run_args_hash exactly."""
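+    # e.g. (hypothetical values): _run_args_hash({"n_neighbors": 15}, {})
+    # equals _run_args_hash({"n_neighbors": 15}), because an empty gk
+    # falls back to the embed_args-only payload below.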
+    if gk:
+        payload: Any = {"embed_args": ea or {}, "generator_kwargs": gk}
+    else:
+        payload = ea or {}
+    s = json.dumps(payload, sort_keys=True, default=str)
     return hashlib.sha1(s.encode()).hexdigest()[:8]
 
 
@@ -45,7 +54,7 @@ def _flow_run_name() -> str:
     T = p.get("num_timesteps", "?")
     J = p.get("jitter_scale", "?")
     s = p.get("seed", "?")
-    tag = _embed_args_hash(p.get("embed_args"))
+    tag = _run_args_hash(p.get("embed_args"), p.get("generator_kwargs"))
     return f"{gen}_{emb}_N{N}_T{T}_J{J}_s{s}_{tag}"
 
 from prefect import flow, runtime, task
@@ -302,7 +311,7 @@ def embedding_flow(
     output_ref: str = (
         f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
     )
-    _args_tag = _embed_args_hash(embed_args)
+    _args_tag = _run_args_hash(embed_args, generator_kwargs)
     output_embed: str = (
         f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
     )
@@ -396,6 +405,7 @@ def embedding_flow(
         "jitter_scale": jitter_scale,
         "seed": seed,
         "generator_path": generator_path,
+        "generator_kwargs": generator_kwargs or {},
         "embedder": embedder,
         "embed_args": merged_embed_args,
     },
@@ -416,6 +426,9 @@ def embedding_flow(
     _sys.path.insert(0, _root)
     from app.web.plotly_parse import parse_plotly_run
     frames = parse_plotly_run(emb_path_result)
+    # Persist generator_kwargs so the server's label enrichment can
+    # regenerate the correct dataset variant (swiss_roll vs hole).
+    frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {}
     Path(output_frames).write_text(
         json.dumps(frames, separators=(",", ":")), encoding="utf-8"
     )
diff --git a/scripts/backfill_hashes.py b/scripts/backfill_hashes.py
index 729fc9e..0df2eeb 100644
--- a/scripts/backfill_hashes.py
+++ b/scripts/backfill_hashes.py
@@ -1,13 +1,18 @@
-"""Rename pre-hash embedder figs to include the embed_args hash suffix.
+"""Rename embedder figs to the current hash scheme (embed_args + generator_kwargs).
 
-Walks figs/ for `.html` files matching the old stem shape (no hash tail) that
-represent an embedder run (not Reference), reads the sibling
-`.metrics.json` to recover `meta.embed_args`, computes the hash, and
-renames the .html + .metrics.json in place.
+Three waves of runs may exist on disk:
+  (1) pre-hash     — `<base>.html`
+  (2) intermediate — `<base>_<embed_args_hash>.html` (from the first hash rollout)
+  (3) current      — `<base>_<run_args_hash>.html` when gen_kwargs is truthy;
+      identical to (2) when gen_kwargs is empty.
 
-Default is a dry-run — pass `--apply` to actually rename. Reference files are
-left alone (they have no embed_args). Missing metrics.json → warn and skip.
-Target-name collision → warn and skip.
+This script queries Prefect for each recent run's full params (so it knows
+generator_kwargs — which the metrics.json sidecar didn't persist before), finds
+the matching fig on disk, renames it to the current stem, and injects
+`meta.generator_kwargs` into the metrics.json so the web server's label
+enrichment disambiguates swiss_roll vs swiss_roll_hole etc.
+
+Dry-run by default. Pass --apply to rename.
 
 Usage:
     .venv/bin/python scripts/backfill_hashes.py [--apply] [--figs-dir PATH]
@@ -16,65 +21,91 @@ Usage:
 from __future__ import annotations
 
 import argparse
+import asyncio
+import hashlib
 import json
-import re
 import sys
 from pathlib import Path
+from typing import Any, Dict, List, Optional
 
-# Reach up to the project root so we can reuse the canonical hash helper.
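+# Reach up to the project root so we can reuse the canonical hash helper and
+# the web app's PREFECT client for querying recent runs.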
 _ROOT = Path(__file__).resolve().parent.parent
 sys.path.insert(0, str(_ROOT))
-from app.web.main import embed_args_hash  # noqa: E402
-
-_LEGACY_STEM = re.compile(
-    r"^(?P<stem>make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+)$"
-)
+from app.web.main import PREFECT, run_args_hash  # noqa: E402
 
 
-def plan_renames(figs_dir: Path):
-    for html in sorted(figs_dir.glob("*.html")):
-        stem = html.stem
-        m = _LEGACY_STEM.match(stem)
-        if not m:
-            # Either already hashed or doesn't match our scheme at all.
-            continue
-        # Skip Reference runs — they have no embed_args.
-        if "_Reference_" in stem:
-            continue
-        metrics = figs_dir / f"{stem}.metrics.json"
-        if not metrics.is_file():
-            yield (html, None, "missing metrics.json — can't compute hash")
-            continue
-        try:
-            ea = json.loads(metrics.read_text(encoding="utf-8"))["meta"]["embed_args"]
-        except (KeyError, json.JSONDecodeError) as e:
-            yield (html, None, f"bad metrics.json: {e}")
-            continue
-        new_stem = f"{stem}_{embed_args_hash(ea)}"
-        new_html = figs_dir / f"{new_stem}.html"
-        if new_html.exists():
-            yield (html, None, f"target exists: {new_html.name}")
-            continue
-        yield (html, new_stem, None)
+def _legacy_hash(ea: Optional[Dict[str, Any]]) -> str:
+    s = json.dumps(ea or {}, sort_keys=True, default=str)
+    return hashlib.sha1(s.encode()).hexdigest()[:8]
 
 
-def apply_rename(figs_dir: Path, old_stem: str, new_stem: str) -> list[str]:
-    """Rename every sidecar sharing the old stem. Returns the renamed files."""
-    renamed = []
+def _base_stem(params: Dict[str, Any]) -> Optional[str]:
+    try:
+        gen = (params.get("generator_path") or "").rsplit(".", 1)[-1]
+        emb = (params.get("embedder") or "").rsplit(".", 1)[-1]
+        N = int(params["num_points"])
+        T = int(params.get("num_timesteps", params.get("num_snapshots")))
+        J = float(params["jitter_scale"])
+        s = int(params["seed"])
+    except (KeyError, TypeError, ValueError):
+        return None
+    if not gen or not emb:
+        return None
+    return f"{gen}_{emb}_N{N}_T{T}_J{J}_s{s}"
+
+
+def _candidate_names(base: str, ea: Dict[str, Any], gk: Dict[str, Any]) -> List[str]:
+    target = f"{base}_{run_args_hash(ea, gk)}.html"
+    legacy = f"{base}_{_legacy_hash(ea)}.html"
+    no_hash = f"{base}.html"
+    # Preserve order: target first so we short-circuit on already-backfilled.
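+    # (A fig already renamed by an earlier --apply pass is thus matched at
+    # its target name first and only needs its metrics.json patched.)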
+    out = [target]
+    for x in (legacy, no_hash):
+        if x not in out:
+            out.append(x)
+    return out
+
+
+def _patch_metrics(path: Path, gk: Dict[str, Any]) -> bool:
+    if not path.is_file():
+        return False
+    try:
+        d = json.loads(path.read_text(encoding="utf-8"))
+    except Exception:
+        return False
+    meta = d.setdefault("meta", {})
+    if meta.get("generator_kwargs") == gk:
+        return False
+    meta["generator_kwargs"] = gk
+    path.write_text(json.dumps(d, indent=2), encoding="utf-8")
+    return True
+
+
+def _rename_bundle(figs_dir: Path, old_stem: str, new_stem: str) -> List[str]:
+    moved = []
     for suffix in (".html", ".metrics.json", ".frames.json"):
         src = figs_dir / f"{old_stem}{suffix}"
         if not src.exists():
             continue
         dst = figs_dir / f"{new_stem}{suffix}"
+        if dst.exists():
+            moved.append(f"SKIP (target exists) {src.name}")
+            continue
         src.rename(dst)
-        renamed.append(f"{src.name} -> {dst.name}")
-    return renamed
+        moved.append(f"{src.name} -> {dst.name}")
+    return moved
+
+
+async def _fetch_runs(limit: int = 200) -> List[Dict[str, Any]]:
+    import httpx
+    async with httpx.AsyncClient(timeout=10.0) as c:
+        return await PREFECT.recent_runs(c, limit=limit)
 
 
 def main() -> int:
     ap = argparse.ArgumentParser(description=__doc__)
-    ap.add_argument("--apply", action="store_true", help="actually rename (default: dry-run)")
+    ap.add_argument("--apply", action="store_true", help="actually rename + patch (default: dry-run)")
     ap.add_argument("--figs-dir", default=str(_ROOT / "figs"), help="path to figs/ directory")
+    ap.add_argument("--limit", type=int, default=200, help="Prefect runs to scan")
     args = ap.parse_args()
 
     figs_dir = Path(args.figs_dir).resolve()
@@ -82,36 +113,65 @@ def main() -> int:
         print(f"no such directory: {figs_dir}", file=sys.stderr)
         return 2
 
-    planned, skipped = [], []
-    for html, new_stem, reason in plan_renames(figs_dir):
-        if new_stem is None:
-            skipped.append((html.name, reason))
-        else:
-            planned.append((html.stem, new_stem))
+    try:
+        runs = asyncio.run(_fetch_runs(limit=args.limit))
+    except Exception as e:
+        print(f"could not reach Prefect at {PREFECT.base} ({e})", file=sys.stderr)
+        return 3
 
-    print(f"scanning {figs_dir}")
-    print(f"  {len(planned)} to rename, {len(skipped)} skipped\n")
+    plans = []  # (old_stem, new_stem, gk, found_name, at_target)
+    seen_targets = set()
+    for r in runs:
+        params = r.get("parameters") or {}
+        ea = params.get("embed_args") or {}
+        gk = params.get("generator_kwargs") or {}
+        base = _base_stem(params)
+        if not base:
+            continue
+        target = f"{base}_{run_args_hash(ea, gk)}.html"
+        if target in seen_targets:
+            continue  # later duplicate — the stale-marking logic will handle it
+        for candidate in _candidate_names(base, ea, gk):
+            if (figs_dir / candidate).exists():
+                if candidate == target:
+                    # Already at target; just ensure metrics.json carries gk.
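+                    # (at_target=True below: the apply phase skips the
+                    # rename and only patches the sidecar metrics.json.)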
+                    plans.append((Path(candidate).stem, Path(target).stem, gk, candidate, True))
+                else:
+                    plans.append((Path(candidate).stem, Path(target).stem, gk, candidate, False))
+                seen_targets.add(target)
+                break
 
-    for old, new in planned:
-        print(f"  rename {old} -> {new}")
-    if skipped:
-        print("\n  skipped:")
-        for name, reason in skipped:
-            print(f"  {name} ({reason})")
+    print(f"scanning {figs_dir} (Prefect runs seen: {len(runs)})")
+    renames = [p for p in plans if not p[4]]
+    already = [p for p in plans if p[4]]
+    print(f"  {len(renames)} to rename, {len(already)} already at target\n")
 
-    if not planned:
+    for old, new, gk, _, _ in renames:
+        gk_str = json.dumps(gk) if gk else "{}"
+        print(f"  rename {old} -> {new}  gen_kwargs={gk_str}")
+
+    if already:
+        print("\n  at-target (will only patch metrics.json if missing gen_kwargs):")
+        for old, _, gk, name, _ in already:
+            print(f"  {name}  gen_kwargs={json.dumps(gk) if gk else '{}'}")
+
+    if not renames and not already:
+        print("nothing to do")
         return 0
 
     if not args.apply:
-        print("\n(dry run — pass --apply to rename)")
+        print("\n(dry run — pass --apply to rename + patch)")
        return 0
 
     print("\napplying...")
-    for old, new in planned:
-        moved = apply_rename(figs_dir, old, new)
-        for line in moved:
-            print(f"  {line}")
-    print(f"done — renamed {len(planned)} run(s)")
+    patched_total = 0
+    for old, new, gk, _, at_target in plans:
+        if not at_target:
+            for line in _rename_bundle(figs_dir, old, new):
+                print(f"  {line}")
+        if _patch_metrics(figs_dir / f"{new}.metrics.json", gk):
+            patched_total += 1
+            print(f"  patched {new}.metrics.json (generator_kwargs)")
+    print(f"done — renamed {len(renames)}, patched metrics for {patched_total} run(s)")
     return 0