stems: fold generator_kwargs into the hash; fix swiss_roll vs hole ambiguity

- run_args_hash now covers (embed_args, generator_kwargs). When gen_kwargs
  is empty we still hash embed_args alone — so plain generators (s_curve,
  plain swiss_roll) keep their stems and no existing plain-gen figs need
  renaming. Kwargs-bearing variants (swiss_roll_hole, blobs,
  gaussian_quantiles, classification) now disambiguate properly.
- Flow persists generator_kwargs into metrics.json meta AND into the
  frames.json sidecar meta, so the label-enrichment path can find it
  without another lookup.
- _enrich_with_labels discovers gen_kwargs in priority order: payload meta -->
  sibling metrics.json --> DATASET_META first-match. When kwargs are known it
  matches the DATASET_META entry by (path, kwargs), so swiss_roll_hole is no
  longer confused with plain swiss_roll.
- _cached_frames overrides meta.stem with the URL-requested stem before
  enrichment — after a backfill rename the sidecar's baked-in stem is
  stale, and we were then failing to find the sibling metrics.json.
- Submit duplicate-check uses the new hash and keeps the hashless-legacy
  check as a safety net.
- backfill_hashes.py rewritten: queries Prefect for each recent run's
  full params, finds the matching fig under any of (current, legacy,
  hashless) names, renames to the current scheme and patches
  generator_kwargs into metrics.json.
This commit is contained in:
Michael Pilosov 2026-04-22 16:30:42 -06:00
parent 44de8deeeb
commit b744c48348
3 changed files with 231 additions and 94 deletions

View File

@ -450,13 +450,31 @@ def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def run_args_hash(
    embed_args: Optional[Dict[str, Any]],
    generator_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """Return the 8-hex-character SHA-1 tag for a run's arguments.

    The digest covers ``embed_args`` and, when it is non-empty,
    ``generator_kwargs``. With an empty/None ``generator_kwargs`` only
    ``embed_args`` is hashed, which preserves the stems of the plain
    generators (s_curve, plain swiss_roll) that never had gen_kwargs.
    Kwargs-bearing variants (swiss_roll_hole, blobs, gaussian_quantiles,
    classification) get a tag distinct from their kwargs-less siblings;
    run scripts/backfill_hashes.py to rehash existing figs.

    Keys are sorted before hashing, and values JSON cannot encode are
    stringified (``default=str``).
    """
    hashed: Any = embed_args or {}
    if generator_kwargs:
        # Only wrap when there is something to add, so kwargs-less runs keep
        # the exact digest they had when embed_args was hashed on its own.
        hashed = {"embed_args": hashed, "generator_kwargs": generator_kwargs}
    serialized = json.dumps(hashed, sort_keys=True, default=str)
    digest = hashlib.sha1(serialized.encode()).hexdigest()
    return digest[:8]


# Back-compat alias — some call sites passed only embed_args.
embed_args_hash = run_args_hash
def synthesize_output_paths( def synthesize_output_paths(
generator_path: str, generator_path: str,
embedder: str, embedder: str,
@ -465,6 +483,7 @@ def synthesize_output_paths(
jitter_scale: float, jitter_scale: float,
seed: int, seed: int,
embed_args: Optional[Dict[str, Any]] = None, embed_args: Optional[Dict[str, Any]] = None,
generator_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
gen = generator_path.split(".")[-1] gen = generator_path.split(".")[-1]
emb = embedder.split(".")[-1] emb = embedder.split(".")[-1]
@ -473,7 +492,7 @@ def synthesize_output_paths(
if embed_args is None: if embed_args is None:
embf = f"{base}.html" embf = f"{base}.html"
else: else:
embf = f"{base}_{embed_args_hash(embed_args)}.html" embf = f"{base}_{run_args_hash(embed_args, generator_kwargs)}.html"
return ref, embf return ref, embf
@ -620,6 +639,7 @@ def _run_view(run: Dict[str, Any]) -> Dict[str, Any]:
float(params.get("jitter_scale", 0.01)), float(params.get("jitter_scale", 0.01)),
int(params.get("seed", 42)), int(params.get("seed", 42)),
embed_args=params.get("embed_args") or {}, embed_args=params.get("embed_args") or {},
generator_kwargs=params.get("generator_kwargs") or {},
) )
# Older runs may lack the hash suffix; prefer legacy name on disk. # Older runs may lack the hash suffix; prefer legacy name on disk.
emb_file = _resolve_emb_file(emb_file) emb_file = _resolve_emb_file(emb_file)
@ -788,13 +808,12 @@ async def submit(request: Request) -> HTMLResponse:
embed_args = build_embed_args(reducer, data) embed_args = build_embed_args(reducer, data)
# Reject submissions whose output path would overwrite an existing fig. # Reject submissions whose output path would overwrite an existing fig.
# The stem now includes an 8-hex hash of embed_args, so UMAP(n_neighbors=5) # Hash now covers both embed_args and generator_kwargs, so swiss_roll vs
# and UMAP(n_neighbors=15) produce distinct files. Check both the hashed # swiss_roll_hole (and blobs with varying n_features, etc.) no longer
# path (new runs) and the legacy hashless path (pre-hash runs) so users # share a stem. Also check the legacy hashless path for pre-hash figs.
# can't accidentally duplicate against a pre-hash fig either.
_, hashed_emb = synthesize_output_paths( _, hashed_emb = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed, generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
embed_args=embed_args, embed_args=embed_args, generator_kwargs=generator_kwargs,
) )
_, legacy_emb = synthesize_output_paths( _, legacy_emb = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed, generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
@ -838,7 +857,7 @@ async def submit(request: Request) -> HTMLResponse:
ref_file, emb_file = synthesize_output_paths( ref_file, emb_file = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed, generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
embed_args=embed_args, embed_args=embed_args, generator_kwargs=generator_kwargs,
) )
RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file} RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file}
@ -895,20 +914,61 @@ for _m in DATASET_META.values():
_GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m) _GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m)
def _lookup_dataset_meta(
    generator_short: str, generator_kwargs: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Pick the DATASET_META entry for a generator, using kwargs to disambiguate.

    Entries are candidates when the last dotted component of their ``path``
    equals ``generator_short``. Returns None when there is no candidate.
    When ``generator_kwargs`` is not None, the first candidate whose
    ``kwargs`` equals it is returned. Otherwise — kwargs unknown, or known
    but matching no candidate — the first candidate wins, which is ambiguous
    for swiss_roll vs swiss_roll_hole (both share ``make_swiss_roll``).
    """
    same_generator = [
        entry
        for entry in DATASET_META.values()
        if entry["path"].rsplit(".", 1)[-1] == generator_short
    ]
    if not same_generator:
        return None
    if generator_kwargs is None:
        return same_generator[0]
    exact = next(
        (entry for entry in same_generator if entry["kwargs"] == generator_kwargs),
        None,
    )
    # Known-but-unmatched kwargs still fall back to the first candidate.
    return same_generator[0] if exact is None else exact
def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]: def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
"""Attach per-point class/continuous labels by regenerating the dataset """Attach per-point class/continuous labels by regenerating the dataset
with the same (generator, n_samples, kwargs). The stem's `seed` drives with the same (generator, n_samples, kwargs). random_state is fixed at 0
jitter NOT generator so we always use random_state=0 to match the (the flow's _DEFAULT_GENERATOR_KWARGS) — the stem's `seed` drives jitter,
flow's _DEFAULT_GENERATOR_KWARGS. Jitter-added points (id >= num_points) not the generator. Jitter-added points (id >= num_points) get None so
get None so the client renders them as black.""" the client renders them as black.
meta = _GEN_TO_META.get(d["meta"].get("generator") or "")
if not meta: Discovers generator_kwargs in priority order: (1) payload meta (sidecar
return d runs from the updated flow); (2) sibling metrics.json; (3) DATASET_META
by first-match (ambiguous for swiss_roll/swiss_roll_hole need a
backfilled metrics.json to disambiguate)."""
meta = d.get("meta") or {}
gen_short = meta.get("generator") or ""
gk = meta.get("generator_kwargs")
if gk is None:
stem = meta.get("stem")
if stem:
mx = FIGS_DIR / f"{stem}.metrics.json"
if mx.is_file():
try: try:
mod_path, cls_name = meta["path"].rsplit(".", 1) gk = json.loads(mx.read_text(encoding="utf-8")).get(
"meta", {}
).get("generator_kwargs")
except Exception:
gk = None
dm = _lookup_dataset_meta(gen_short, gk)
if not dm:
return d
kwargs_to_use = gk if gk is not None else dm["kwargs"]
try:
mod_path, cls_name = dm["path"].rsplit(".", 1)
fn = getattr(importlib.import_module(mod_path), cls_name) fn = getattr(importlib.import_module(mod_path), cls_name)
N = int(d["meta"]["num_points"]) N = int(meta["num_points"])
_, gen_labels = fn(n_samples=N, random_state=0, **meta["kwargs"]) _, gen_labels = fn(n_samples=N, random_state=0, **kwargs_to_use)
out_labels: List[Optional[float]] = [] out_labels: List[Optional[float]] = []
for pid in d["point_ids"]: for pid in d["point_ids"]:
if isinstance(pid, int) and 0 <= pid < N: if isinstance(pid, int) and 0 <= pid < N:
@ -917,7 +977,7 @@ def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
else: else:
out_labels.append(None) out_labels.append(None)
d["labels"] = out_labels d["labels"] = out_labels
d["label_kind"] = meta["kind"] d["label_kind"] = dm["kind"]
except Exception: except Exception:
pass pass
return d return d
@ -934,6 +994,10 @@ def _cached_frames(stem: str) -> str:
else: else:
html = FIGS_DIR / f"{stem}.html" html = FIGS_DIR / f"{stem}.html"
d = parse_plotly_run(html) d = parse_plotly_run(html)
# Override meta.stem with the URL-requested stem — after a backfill the
# file was renamed but the baked-in meta.stem still points at the old
# name. Enrichment uses this to find the sibling metrics.json.
d.setdefault("meta", {})["stem"] = stem
d = _enrich_with_labels(d) d = _enrich_with_labels(d)
return json.dumps(d, separators=(",", ":")) return json.dumps(d, separators=(",", ":"))

View File

@ -27,10 +27,19 @@ from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
def _run_args_hash(
    ea: Optional[Dict[str, Any]],
    gk: Optional[Dict[str, Any]] = None,
) -> str:
    """Return the 8-hex SHA-1 tag over (embed_args, generator_kwargs).

    When ``gk`` is empty/None only ``ea`` is hashed, which keeps stems stable
    for plain generators that never had gen_kwargs (s_curve, plain
    swiss_roll). Must mirror app.web.main.run_args_hash exactly.
    """
    embed_part = ea or {}
    payload: Any = (
        {"embed_args": embed_part, "generator_kwargs": gk} if gk else embed_part
    )
    canonical = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha1(canonical.encode()).hexdigest()[:8]
@ -45,7 +54,7 @@ def _flow_run_name() -> str:
T = p.get("num_timesteps", "?") T = p.get("num_timesteps", "?")
J = p.get("jitter_scale", "?") J = p.get("jitter_scale", "?")
s = p.get("seed", "?") s = p.get("seed", "?")
tag = _embed_args_hash(p.get("embed_args")) tag = _run_args_hash(p.get("embed_args"), p.get("generator_kwargs"))
return f"{gen}_{emb}_N{N}_T{T}_J{J}_s{s}_{tag}" return f"{gen}_{emb}_N{N}_T{T}_J{J}_s{s}_{tag}"
from prefect import flow, runtime, task from prefect import flow, runtime, task
@ -302,7 +311,7 @@ def embedding_flow(
output_ref: str = ( output_ref: str = (
f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html" f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
) )
_args_tag = _embed_args_hash(embed_args) _args_tag = _run_args_hash(embed_args, generator_kwargs)
output_embed: str = ( output_embed: str = (
f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html" f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
) )
@ -396,6 +405,7 @@ def embedding_flow(
"jitter_scale": jitter_scale, "jitter_scale": jitter_scale,
"seed": seed, "seed": seed,
"generator_path": generator_path, "generator_path": generator_path,
"generator_kwargs": generator_kwargs or {},
"embedder": embedder, "embedder": embedder,
"embed_args": merged_embed_args, "embed_args": merged_embed_args,
}, },
@ -416,6 +426,9 @@ def embedding_flow(
_sys.path.insert(0, _root) _sys.path.insert(0, _root)
from app.web.plotly_parse import parse_plotly_run from app.web.plotly_parse import parse_plotly_run
frames = parse_plotly_run(emb_path_result) frames = parse_plotly_run(emb_path_result)
# Persist generator_kwargs so the server's label enrichment can
# regenerate the correct dataset variant (swiss_roll vs hole).
frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {}
Path(output_frames).write_text( Path(output_frames).write_text(
json.dumps(frames, separators=(",", ":")), encoding="utf-8" json.dumps(frames, separators=(",", ":")), encoding="utf-8"
) )

View File

@ -1,13 +1,18 @@
"""Rename pre-hash embedder figs to include the embed_args hash suffix. """Rename embedder figs to the current hash scheme (embed_args + generator_kwargs).
Walks figs/ for `.html` files matching the old stem shape (no hash tail) that Two waves of runs may exist on disk:
represent an embedder run (not Reference), reads the sibling (1) pre-hash `<stem>.html`
`<stem>.metrics.json` to recover `meta.embed_args`, computes the hash, and (2) intermediate `<stem>_<sha1(embed_args)>.html` (from the first hash rollout)
renames the .html + .metrics.json in place. (3) current `<stem>_<sha1(embed_args, gen_kwargs)>.html` when gen_kwargs is truthy;
identical to (2) when gen_kwargs is empty.
Default is a dry-run pass `--apply` to actually rename. Reference files are This script queries Prefect for each recent run's full params (so it knows
left alone (they have no embed_args). Missing metrics.json warn and skip. generator_kwargs which the metrics.json sidecar didn't persist before), finds
Target-name collision warn and skip. the matching fig on disk, renames to the current stem, and injects
`meta.generator_kwargs` into the metrics.json so the web server's label
enrichment disambiguates swiss_roll vs swiss_roll_hole etc.
Dry-run by default. Pass --apply to rename.
Usage: Usage:
.venv/bin/python scripts/backfill_hashes.py [--apply] [--figs-dir PATH] .venv/bin/python scripts/backfill_hashes.py [--apply] [--figs-dir PATH]
@ -16,65 +21,91 @@ Usage:
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import asyncio
import hashlib
import json import json
import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional
# Reach up to the project root so we can reuse the canonical hash helper.
_ROOT = Path(__file__).resolve().parent.parent _ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_ROOT)) sys.path.insert(0, str(_ROOT))
from app.web.main import embed_args_hash # noqa: E402 from app.web.main import PREFECT, run_args_hash # noqa: E402
_LEGACY_STEM = re.compile(
r"^(?P<base>make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+)$"
)
def _legacy_hash(ea: Optional[Dict[str, Any]]) -> str:
    """Return the 8-hex SHA-1 digest of embed_args alone (keys sorted).

    This ignores generator_kwargs; it is used to build the "legacy"
    candidate file name when looking for figs on disk.
    """
    canonical = json.dumps({} if not ea else ea, sort_keys=True, default=str)
    return hashlib.sha1(canonical.encode()).hexdigest()[:8]
m = _LEGACY_STEM.match(stem)
if not m:
def _base_stem(params: Dict[str, Any]) -> Optional[str]:
    """Build the hash-less stem ``<gen>_<emb>_N.._T.._J.._s..`` from run params.

    ``gen`` and ``emb`` are the last dotted components of ``generator_path``
    and ``embedder``. The timestep count is read from ``num_timesteps``,
    falling back to ``num_snapshots`` when that key is absent.

    Returns None when a required key is missing, a value cannot be coerced
    to int/float, or either short name is empty.
    """
    try:
        gen_short = (params.get("generator_path") or "").rsplit(".", 1)[-1]
        emb_short = (params.get("embedder") or "").rsplit(".", 1)[-1]
        n_points = int(params["num_points"])
        timesteps = int(params.get("num_timesteps", params.get("num_snapshots")))
        jitter = float(params["jitter_scale"])
        seed = int(params["seed"])
    except (KeyError, TypeError, ValueError):
        return None
    if gen_short and emb_short:
        return f"{gen_short}_{emb_short}_N{n_points}_T{timesteps}_J{jitter}_s{seed}"
    return None
def _candidate_names(base: str, ea: Dict[str, Any], gk: Dict[str, Any]) -> List[str]:
    """List the ``.html`` file names a run's fig may have on disk, without duplicates.

    Order is: current-scheme name (hash of embed_args + generator_kwargs),
    legacy name (hash of embed_args only), then the hash-less name. The
    target comes first so callers short-circuit on already-backfilled figs.
    """
    names_in_priority = (
        f"{base}_{run_args_hash(ea, gk)}.html",
        f"{base}_{_legacy_hash(ea)}.html",
        f"{base}.html",
    )
    # dict keys keep first-seen order, so this de-duplicates while preserving
    # priority (current and legacy names coincide when gk is empty).
    return list(dict.fromkeys(names_in_priority))
def _patch_metrics(path: Path, gk: Dict[str, Any]) -> bool:
    """Store ``gk`` as ``meta.generator_kwargs`` in the metrics JSON at ``path``.

    Returns True only when the file was rewritten. Returns False, leaving
    the file untouched, when it is missing, unreadable, not valid JSON, not
    a JSON object whose ``meta`` (if present) is also an object, or already
    carries an equal ``generator_kwargs``.
    """
    if not path.is_file():
        return False
    try:
        doc = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return False
    # A well-formed but unexpected shape (e.g. a top-level list, or
    # "meta": null) must be skipped, not crash the whole backfill run.
    if not isinstance(doc, dict):
        return False
    meta = doc.setdefault("meta", {})
    if not isinstance(meta, dict):
        return False
    if meta.get("generator_kwargs") == gk:
        return False
    meta["generator_kwargs"] = gk
    path.write_text(json.dumps(doc, indent=2), encoding="utf-8")
    return True
def _rename_bundle(figs_dir: Path, old_stem: str, new_stem: str) -> List[str]:
    """Rename a fig's ``.html`` / ``.metrics.json`` / ``.frames.json`` together.

    Only files that exist under ``old_stem`` take part. The rename is
    all-or-nothing: if any destination already exists, nothing is moved, so
    a bundle is never split across two stems and a stale sidecar at the
    target is never paired with freshly renamed siblings.

    Returns one human-readable line per participating file: either
    ``"<src> -> <dst>"`` for a rename, or a ``SKIP ...`` line when the
    bundle was held back.
    """
    pairs = []
    for suffix in (".html", ".metrics.json", ".frames.json"):
        src = figs_dir / f"{old_stem}{suffix}"
        if src.exists():
            pairs.append((src, figs_dir / f"{new_stem}{suffix}"))

    # Detect every collision before touching the filesystem.
    blocked = {src for src, dst in pairs if dst.exists()}
    if blocked:
        return [
            f"SKIP (target exists) {src.name}"
            if src in blocked
            else f"SKIP (sibling target exists) {src.name}"
            for src, _ in pairs
        ]

    moved = []
    for src, dst in pairs:
        src.rename(dst)
        moved.append(f"{src.name} -> {dst.name}")
    return moved
async def _fetch_runs(limit: int = 200) -> List[Dict[str, Any]]:
    """Fetch up to ``limit`` recent runs via ``PREFECT.recent_runs``.

    Opens a short-lived ``httpx.AsyncClient`` (10 s timeout) for the call and
    closes it before returning.
    """
    import httpx

    async with httpx.AsyncClient(timeout=10.0) as client:
        runs = await PREFECT.recent_runs(client, limit=limit)
    return runs
def main() -> int: def main() -> int:
ap = argparse.ArgumentParser(description=__doc__) ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--apply", action="store_true", help="actually rename (default: dry-run)") ap.add_argument("--apply", action="store_true", help="actually rename + patch (default: dry-run)")
ap.add_argument("--figs-dir", default=str(_ROOT / "figs"), help="path to figs/ directory") ap.add_argument("--figs-dir", default=str(_ROOT / "figs"), help="path to figs/ directory")
ap.add_argument("--limit", type=int, default=200, help="Prefect runs to scan")
args = ap.parse_args() args = ap.parse_args()
figs_dir = Path(args.figs_dir).resolve() figs_dir = Path(args.figs_dir).resolve()
@ -82,36 +113,65 @@ def main() -> int:
print(f"no such directory: {figs_dir}", file=sys.stderr) print(f"no such directory: {figs_dir}", file=sys.stderr)
return 2 return 2
planned, skipped = [], [] try:
for html, new_stem, reason in plan_renames(figs_dir): runs = asyncio.run(_fetch_runs(limit=args.limit))
if new_stem is None: except Exception as e:
skipped.append((html.name, reason)) print(f"could not reach Prefect at {PREFECT.base} ({e})", file=sys.stderr)
return 3
plans = [] # (old_stem, new_stem, gk, found_name)
seen_targets = set()
for r in runs:
params = r.get("parameters") or {}
ea = params.get("embed_args") or {}
gk = params.get("generator_kwargs") or {}
base = _base_stem(params)
if not base:
continue
target = f"{base}_{run_args_hash(ea, gk)}.html"
if target in seen_targets:
continue # later duplicate — the stale-marking logic will handle it
for candidate in _candidate_names(base, ea, gk):
if (figs_dir / candidate).exists():
if candidate == target:
# Already at target; just ensure metrics.json carries gk.
plans.append((Path(candidate).stem, Path(target).stem, gk, candidate, True))
else: else:
planned.append((html.stem, new_stem)) plans.append((Path(candidate).stem, Path(target).stem, gk, candidate, False))
seen_targets.add(target)
break
print(f"scanning {figs_dir}") print(f"scanning {figs_dir} (Prefect runs seen: {len(runs)})")
print(f" {len(planned)} to rename, {len(skipped)} skipped\n") renames = [p for p in plans if not p[4]]
already = [p for p in plans if p[4]]
print(f" {len(renames)} to rename, {len(already)} already at target\n")
for old, new in planned: for old, new, gk, _, _ in renames:
print(f" rename {old} -> {new}") gk_str = json.dumps(gk) if gk else "{}"
if skipped: print(f" rename {old} -> {new} gen_kwargs={gk_str}")
print("\n skipped:")
for name, reason in skipped:
print(f" {name} ({reason})")
if not planned: if already:
print(f"\n at-target (will only patch metrics.json if missing gen_kwargs):")
for old, _, gk, name, _ in already:
print(f" {name} gen_kwargs={json.dumps(gk) if gk else '{}'}")
if not renames and not already:
print("nothing to do")
return 0 return 0
if not args.apply: if not args.apply:
print("\n(dry run — pass --apply to rename)") print("\n(dry run — pass --apply to rename + patch)")
return 0 return 0
print("\napplying...") print("\napplying...")
for old, new in planned: for old, new, gk, _, at_target in plans:
moved = apply_rename(figs_dir, old, new) if not at_target:
for line in moved: for line in _rename_bundle(figs_dir, old, new):
print(f" {line}") print(f" {line}")
print(f"done — renamed {len(planned)} run(s)") patched = _patch_metrics(figs_dir / f"{new}.metrics.json", gk)
if patched:
print(f" patched {new}.metrics.json (generator_kwargs)")
print(f"done — renamed {len(renames)}, patched metrics for {len(plans)} run(s)")
return 0 return 0