""" web1 — "Scientific instrument / research notebook" A FastAPI UI for kicking off the embedding-flow Prefect deployment and viewing the resulting HTML animations. Design: restrained, typography-driven, two-column notebook layout. No CSS framework; hand-written styles. """ from __future__ import annotations import hashlib import importlib.util import json import os import re from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from app.web.plotly_parse import parse_plotly_run import httpx from fastapi import FastAPI, Form, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse, Response from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sklearn.datasets import ( make_blobs, make_classification, make_gaussian_quantiles, make_s_curve, make_swiss_roll, ) # --------------------------------------------------------------------------- # Paths / constants # --------------------------------------------------------------------------- BASE_DIR = Path(__file__).resolve().parent PROJECT_ROOT = BASE_DIR.parent.parent # /home/mm/work/dr-sandbox FIGS_DIR = PROJECT_ROOT / "figs" FIGS_DIR.mkdir(parents=True, exist_ok=True) PREFECT_API = os.environ.get("PREFECT_API_URL", "http://localhost:4200/api") DEPLOYMENT_NAME = "embedding-flow/embedding-flow" # --------------------------------------------------------------------------- # Dataset catalogue # --------------------------------------------------------------------------- # Metadata for the /data.json endpoint consumed by the dataset picker, and # for server-side lookup when the picker posts its selection back. kwargs # must carry n_features=3 for generators that aren't already 3-D, since # they'll be forwarded verbatim to the Prefect flow's generator_kwargs. DATASET_PREVIEW_N = 5000 DATASET_PREVIEW_SEED = 0 DATASET_META: Dict[str, Dict[str, Any]] = { "s_curve": { "name": "S-Curve", "path": "sklearn.datasets.make_s_curve", "kwargs": {}, "description": ( "A 2-D manifold warped into R³. Continuous label encodes position " "along the curve — a good test of whether a reducer unrolls the " "sheet without tearing." ), "kind": "continuous", }, "swiss_roll": { "name": "Swiss Roll", "path": "sklearn.datasets.make_swiss_roll", "kwargs": {}, "description": ( "A rolled-up plane. The canonical hard case for linear methods: " "PCA collapses the spiral, non-linear methods should recover the " "unroll." ), "kind": "continuous", }, "swiss_roll_hole": { "name": "Swiss Roll (hole)", "path": "sklearn.datasets.make_swiss_roll", "kwargs": {"hole": True}, "description": ( "Swiss roll with a rectangular hole punched through. Same manifold, " "non-trivial topology — a faithful unroll should preserve the hole " "rather than smearing it closed." ), "kind": "continuous", }, "blobs": { "name": "Gaussian Blobs", "path": "sklearn.datasets.make_blobs", "kwargs": {"n_features": 3, "centers": 5, "cluster_std": 1.0}, "description": ( "Five isotropic Gaussian clusters in R³. Discrete class labels. " "Tests whether a reducer preserves cluster separation when " "projected to 2-D." ), "kind": "categorical", }, "gaussian_quantiles": { "name": "Gaussian Quantiles", "path": "sklearn.datasets.make_gaussian_quantiles", "kwargs": {"n_features": 3, "n_classes": 4}, "description": ( "Concentric Gaussian shells in R³; class = which shell. Classes " "are linearly inseparable by construction — PCA collapses them, " "kernel and manifold methods have a chance." 
), "kind": "categorical", }, "classification": { "name": "Hypercube Clusters", "path": "sklearn.datasets.make_classification", "kwargs": { "n_features": 3, "n_informative": 3, "n_redundant": 0, "n_repeated": 0, "n_classes": 4, "n_clusters_per_class": 2, "class_sep": 1.5, }, "description": ( "Four classes, two sub-clusters each, placed at hypercube vertices " "with informative noise. A denser discrete test than blobs — " "within-class bimodality stresses cluster-preserving reducers." ), "kind": "categorical", }, } @lru_cache(maxsize=1) def _dataset_previews() -> Dict[str, Dict[str, Any]]: """Attach freshly-generated points+labels to the catalogue for the picker.""" N, SEED = DATASET_PREVIEW_N, DATASET_PREVIEW_SEED s, sl = make_s_curve(n_samples=N, noise=0.03, random_state=SEED) sr, srl = make_swiss_roll(n_samples=N, noise=0.15, random_state=SEED) srh, srhl = make_swiss_roll(n_samples=N, noise=0.15, hole=True, random_state=SEED) b, bl = make_blobs( n_samples=N, n_features=3, centers=5, cluster_std=1.0, random_state=SEED ) gq, gql = make_gaussian_quantiles( n_samples=N, n_features=3, n_classes=4, random_state=SEED ) cls, clsl = make_classification( n_samples=N, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=4, n_clusters_per_class=2, class_sep=1.5, random_state=SEED, ) samples = { "s_curve": (s, sl), "swiss_roll": (sr, srl), "swiss_roll_hole": (srh, srhl), "blobs": (b, bl), "gaussian_quantiles": (gq, gql), "classification": (cls, clsl), } out: Dict[str, Dict[str, Any]] = {} for key, meta in DATASET_META.items(): pts, labels = samples[key] out[key] = { **meta, "points": pts.tolist(), "labels": labels.tolist(), } return out # --------------------------------------------------------------------------- # Reducer catalogue # --------------------------------------------------------------------------- # Each field tuple: (name, kind, default, choices_or_none, help_or_none) # kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null" REDUCERS: Dict[str, Dict[str, Any]] = { "sklearn.decomposition.PCA": { "pkg": "sklearn", "label": "PCA", "blurb": "Principal component analysis. Linear, fast, deterministic.", "key": [ ("n_components", "int", 2, None, "Locked."), ], "advanced": [ ("svd_solver", "str", "auto", ["auto", "full", "arpack", "randomized"], None), ("random_state", "int", 42, None, None), ("whiten", "bool", False, None, None), ], }, "sklearn.decomposition.FactorAnalysis": { "pkg": "sklearn", "label": "FactorAnalysis", "blurb": "Gaussian latent-factor model with per-feature noise.", "key": [ ("n_components", "int", 2, None, "Locked."), ("random_state", "int", 42, None, None), ], "advanced": [ ("tol", "float", 0.01, None, None), ("max_iter", "int", 1000, None, None), ("rotation", "str", "", ["", "varimax", "quartimax"], "Empty = None."), ], }, "sklearn.decomposition.KernelPCA": { "pkg": "sklearn", "label": "KernelPCA", "blurb": "Non-linear PCA via the kernel trick. 
        "Deterministic; kernel choice shapes the output.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("kernel", "str", "rbf", ["linear", "poly", "rbf", "sigmoid", "cosine"], None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("gamma", "str_or_float", "", None, "Empty = 1/n_features."),
            ("degree", "int", 3, None, None),
            ("coef0", "float", 1.0, None, None),
            ("alpha", "float", 1.0, None, None),
        ],
    },
    "sklearn.manifold.Isomap": {
        "pkg": "sklearn",
        "label": "Isomap",
        "blurb": "Geodesic-distance manifold learning via shortest paths on a k-NN graph.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 5, None, None),
        ],
        "advanced": [
            ("metric", "str", "minkowski", None, None),
            ("p", "int", 2, None, "Minkowski power (1 = manhattan, 2 = euclidean)."),
            ("path_method", "str", "auto", ["auto", "FW", "D"], "Floyd-Warshall / Dijkstra / auto."),
            ("neighbors_algorithm", "str", "auto", ["auto", "ball_tree", "kd_tree", "brute"], None),
        ],
    },
    "sklearn.manifold.MDS": {
        "pkg": "sklearn",
        "label": "MDS",
        "blurb": "Multidimensional scaling. Preserves pairwise distances; O(n²) memory.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_init", "int", 4, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("max_iter", "int", 300, None, None),
            ("metric_mds", "bool", True, None, "Metric (True) vs non-metric MDS."),
            ("metric", "str", "euclidean", None, None),
            ("eps", "float", 1e-6, None, "Convergence tolerance."),
        ],
    },
    "sklearn.manifold.SpectralEmbedding": {
        "pkg": "sklearn",
        "label": "SpectralEmbedding",
        "blurb": "Laplacian eigenmaps on an affinity graph. What UMAP uses for initialisation.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("affinity", "str", "nearest_neighbors", ["nearest_neighbors", "rbf"], None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("n_neighbors", "int_or_null", "", None, "For affinity=nearest_neighbors. Empty = n/10."),
            ("gamma", "str_or_float", "", None, "For affinity=rbf. Empty = 1/n_features."),
        ],
    },
    "sklearn.manifold.TSNE": {
        "pkg": "sklearn",
        "label": "t-SNE",
        "blurb": "Stochastic neighbour embedding. Local structure preserved.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("perplexity", "float", 30.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("learning_rate", "str_or_float", "auto", None, "'auto' or a float."),
            ("n_iter", "int", 1000, None, None),
            ("metric", "str", "euclidean", None, None),
            ("early_exaggeration", "float", 12.0, None, None),
            ("init", "str", "pca", ["pca", "random"], None),
        ],
    },
    "umap.UMAP": {
        "pkg": "umap",
        "label": "UMAP",
        "blurb": "Uniform manifold approximation. Preserves local + some global structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 15, None, None),
            ("min_dist", "float", 0.1, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("metric", "str", "euclidean", None, None),
            ("n_epochs", "int_or_null", "", None, "Empty = None (auto)."),
            ("spread", "float", 1.0, None, None),
            ("init", "str", "spectral", ["spectral", "random"], None),
        ],
    },
    "pacmap.PaCMAP": {
        "pkg": "pacmap",
        "label": "PaCMAP",
        "blurb": "Pairwise-controlled manifold approximation. "
        "Balanced local/global.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
        ],
    },
    "pacmap.LocalMAP": {
        "pkg": "pacmap",
        "label": "LocalMAP",
        "blurb": "PaCMAP variant with a low-distance threshold; sharper local structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
            ("low_dist_thres", "float", 10.0, None, None),
        ],
    },
    "trimap.TRIMAP": {
        "pkg": "trimap",
        "label": "TriMap",
        "blurb": "Triplet-based dimensionality reduction. Emphasises global structure.",
        "key": [
            ("n_dims", "int", 2, None, "Locked."),
            ("n_inliers", "int", 10, None, None),
            ("n_outliers", "int", 5, None, None),
            ("n_random", "int", 5, None, None),
        ],
        "advanced": [
            ("lr", "float", 0.1, None, None),
            ("n_iters", "int", 400, None, None),
            ("weight_adj", "float", 500.0, None, None),
        ],
    },
    "sklearn.random_projection.GaussianRandomProjection": {
        "pkg": "sklearn",
        "label": "GaussianRandomProjection",
        "blurb": "Johnson-Lindenstrauss baseline. Cheap, distance-preserving in expectation, structure-agnostic.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [],
    },
}


def available_reducers() -> List[Tuple[str, Dict[str, Any]]]:
    out = []
    for key, spec in REDUCERS.items():
        if importlib.util.find_spec(spec["pkg"]) is not None:
            out.append((key, spec))
    return out


# ---------------------------------------------------------------------------
# Parameter coercion
# ---------------------------------------------------------------------------
def _coerce(kind: str, raw: Optional[str], default: Any) -> Any:
    if raw is None:
        return default
    s = raw.strip() if isinstance(raw, str) else raw
    if kind == "int":
        if s == "" or s is None:
            return default
        return int(s)
    if kind == "float":
        if s == "" or s is None:
            return default
        return float(s)
    if kind == "bool":
        # Checkbox: "on" / absent.
        return bool(s) and s not in ("0", "false", "False", "")
    if kind == "str":
        if s == "":
            # Empty input: None when the default is itself empty-ish,
            # otherwise fall back to the default.
            return None if default in (None, "") else (default or "")
        return s
    if kind == "str_or_float":
        if s == "":
            return default
        try:
            return float(s)
        except (ValueError, TypeError):
            return s
    if kind == "int_or_null":
        if s == "":
            return None
        return int(s)
    return s
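
# A few illustrative coercions (values follow directly from the branches above):
#   _coerce("int", "7", 5)                   -> 7
#   _coerce("int", "", 5)                    -> 5       (empty field keeps the default)
#   _coerce("str_or_float", "0.5", "auto")   -> 0.5
#   _coerce("str_or_float", "auto", "auto")  -> "auto"  (float() fails, string kept)
#   _coerce("int_or_null", "", None)         -> None    (explicit null, not default)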

def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
    spec = REDUCERS[reducer_key]
    out: Dict[str, Any] = {}
    all_fields = list(spec["key"]) + list(spec["advanced"])
    for (name, kind, default, _choices, _help) in all_fields:
        raw = form.get(f"embed__{name}")
        if kind == "bool":
            # Checkbox semantics: key present ("on") -> True, absent -> False.
            raw_v = "on" if f"embed__{name}" in form else ""
            value = bool(raw_v)
        else:
            value = _coerce(kind, raw, default)
        # Null-stripping: drop empty rotations etc.
        if value is None:
            continue
        if isinstance(value, str) and value == "" and default in (None, ""):
            continue
        out[name] = value
    # Always force n_components / n_dims to 2 (flow assertion).
    if "n_components" in out:
        out["n_components"] = 2
    if "n_dims" in out:
        out["n_dims"] = 2
    return out


# ---------------------------------------------------------------------------
# Output-path synthesis (mirrors flows/embedding_flow.py lines ~162–168)
# ---------------------------------------------------------------------------
def run_args_hash(
    embed_args: Optional[Dict[str, Any]],
    generator_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """8-hex digest of (embed_args, generator_kwargs).

    When generator_kwargs is empty/None we hash embed_args alone — preserves
    stems for the plain generators (s_curve, plain swiss_roll) that never had
    gen_kwargs. For kwargs-bearing variants (swiss_roll_hole, blobs,
    gaussian_quantiles, classification), the hash now disambiguates them from
    their kwargs-less siblings — run scripts/backfill_hashes.py to rehash
    existing figs."""
    if generator_kwargs:
        payload: Any = {
            "embed_args": embed_args or {},
            "generator_kwargs": generator_kwargs,
        }
    else:
        payload = embed_args or {}
    s = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha1(s.encode()).hexdigest()[:8]


# Back-compat alias — some call sites passed only embed_args.
embed_args_hash = run_args_hash


def sci_notation(v: Any) -> str:
    """Float → compact sci notation without a period (0.005 → '5E-3').

    Used in stems and Prefect run names so filenames + UI avoid periods."""
    try:
        f = float(v)
    except (TypeError, ValueError):
        return str(v)
    m, e = f"{f:.3e}".split("e")
    m = m.rstrip("0").rstrip(".")
    return f"{m}E{int(e)}"


def synthesize_output_paths(
    generator_path: str,
    embedder: str,
    num_points: int,
    num_timesteps: int,
    jitter_scale: float,
    seed: int,
    embed_args: Optional[Dict[str, Any]] = None,
    generator_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, str]:
    gen = generator_path.split(".")[-1]
    emb = embedder.split(".")[-1]
    j = sci_notation(jitter_scale)
    ref = f"{gen}_Reference_N{num_points}_T{num_timesteps}_J{j}_s{seed}.html"
    base = f"{gen}_{emb}_N{num_points}_T{num_timesteps}_J{j}_s{seed}"
    if embed_args is None:
        embf = f"{base}.html"
    else:
        embf = f"{base}_{run_args_hash(embed_args, generator_kwargs)}.html"
    return ref, embf
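
# Worked example (illustrative; "<8hex>" below is a placeholder, the real
# suffix is whatever run_args_hash returns for the payload):
#   synthesize_output_paths("sklearn.datasets.make_s_curve",
#                           "sklearn.decomposition.PCA",
#                           5000, 48, 0.01, 42, embed_args={"n_components": 2})
#   -> ("make_s_curve_Reference_N5000_T48_J1E-2_s42.html",
#       "make_s_curve_PCA_N5000_T48_J1E-2_s42_<8hex>.html")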

def _resolve_emb_file(synthesized: str) -> str:
    """Disk-backed fallback: prefer the synthesized (hashed) name; if that
    doesn't exist on disk but an older hash-less variant does, return that so
    pre-hash runs still render in the UI."""
    if (FIGS_DIR / synthesized).exists():
        return synthesized
    # Strip trailing _<8hex>.html to get the legacy name.
    m = re.match(r"^(?P<base>.+)_[0-9a-f]{8}\.html$", synthesized)
    if m:
        legacy = m.group("base") + ".html"
        if (FIGS_DIR / legacy).exists():
            return legacy
    return synthesized  # new / still-running run; let emb_exists resolve


# ---------------------------------------------------------------------------
# Prefect client
# ---------------------------------------------------------------------------
class Prefect:
    def __init__(self, base: str = PREFECT_API) -> None:
        self.base = base.rstrip("/")
        self._deployment_id: Optional[str] = None

    async def deployment_id(self, client: httpx.AsyncClient) -> Optional[str]:
        if self._deployment_id:
            return self._deployment_id
        try:
            r = await client.get(f"{self.base}/deployments/name/{DEPLOYMENT_NAME}")
            if r.status_code == 200:
                self._deployment_id = r.json()["id"]
                return self._deployment_id
        except httpx.HTTPError:
            return None
        return None

    async def create_run(
        self,
        client: httpx.AsyncClient,
        parameters: Dict[str, Any],
        tags: Optional[List[str]] = None,
    ) -> Optional[Dict[str, Any]]:
        dep = await self.deployment_id(client)
        if not dep:
            return None
        body: Dict[str, Any] = {"parameters": parameters}
        if tags:
            body["tags"] = list(tags)
        r = await client.post(
            f"{self.base}/deployments/{dep}/create_flow_run",
            json=body,
        )
        if r.status_code >= 400:
            return {"error": r.text, "status": r.status_code}
        return r.json()

    async def recent_runs(
        self,
        client: httpx.AsyncClient,
        limit: int = 10,
        required_tags: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        dep = await self.deployment_id(client)
        if not dep:
            return []
        flow_runs: Dict[str, Any] = {"deployment_id": {"any_": [dep]}}
        if required_tags:
            flow_runs["tags"] = {"all_": list(required_tags)}
        # Prefect rejects limit > 200 with HTTP 422.
        capped = min(max(1, limit), 200)
        try:
            r = await client.post(
                f"{self.base}/flow_runs/filter",
                json={
                    "sort": "START_TIME_DESC",
                    "limit": capped,
                    "flow_runs": flow_runs,
                },
            )
            if r.status_code == 200:
                return r.json()
        except httpx.HTTPError:
            return []
        return []

    async def update_tags(
        self, client: httpx.AsyncClient, run_id: str, tags: List[str]
    ) -> bool:
        try:
            r = await client.patch(
                f"{self.base}/flow_runs/{run_id}",
                json={"tags": list(tags)},
            )
            return r.status_code < 400
        except httpx.HTTPError:
            return False


PREFECT = Prefect()

# ---------------------------------------------------------------------------
# In-memory mapping: flow_run_id -> synthesized output file names
# (best-effort; lost on restart, which is fine)
# ---------------------------------------------------------------------------
RUN_OUTPUTS: Dict[str, Dict[str, str]] = {}

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(title="web1 — embedding notebook", docs_url=None, redoc_url=None)
app.mount("/figs", StaticFiles(directory=str(FIGS_DIR)), name="figs")
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))


def _fmt_runtime(seconds: Optional[float]) -> Optional[str]:
    if seconds is None or seconds <= 0:
        return None
    if seconds < 60:
        return f"{seconds:.1f}s"
    m, s = divmod(int(seconds), 60)
    if m < 60:
        return f"{m}m{s:02d}s"
    h, m = divmod(m, 60)
    return f"{h}h{m:02d}m"
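
# Sample outputs (illustrative, straight from the arithmetic above):
#   _fmt_runtime(4.2)   -> "4.2s"
#   _fmt_runtime(75)    -> "1m15s"
#   _fmt_runtime(3671)  -> "1h01m"
#   _fmt_runtime(None)  -> None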

def _run_view(run: Dict[str, Any]) -> Dict[str, Any]:
    """Normalise a flow-run dict for the template."""
    rid = run.get("id", "")
    state_type = (run.get("state_type") or "PENDING").upper()
    state_name = run.get("state_name") or state_type.title()
    # Prefect labels a scheduled run that's past its start time "Late". The
    # runner is simply busy / awaiting its turn at the concurrency cap — show
    # "Queued" instead, which is the honest description of the state.
    if state_name == "Late":
        state_name = "Queued"
    start = run.get("start_time") or run.get("expected_start_time") or run.get("created")
    params = run.get("parameters") or {}
    # estimated_run_time recomputes server-side (ticks up while RUNNING);
    # total_run_time is the authoritative value once the run finishes.
    rtime_s = run.get("estimated_run_time") or run.get("total_run_time")

    # Try to look up synthesized outputs either from memory or from params.
    ref_file = None
    emb_file = None
    outs = RUN_OUTPUTS.get(rid)
    if outs:
        ref_file = outs["ref"]
        emb_file = outs["embed"]
    elif params:
        try:
            ref_file, emb_file = synthesize_output_paths(
                params.get("generator_path", "sklearn.datasets.make_s_curve"),
                params.get("embedder", "sklearn.decomposition.FactorAnalysis"),
                int(params.get("num_points", 5000)),
                # Fall back to the old num_snapshots key for runs dispatched
                # before the T-rename, so historical figs still resolve after
                # `rename 's/_S/_T/' figs/*.html`.
                int(params.get("num_timesteps", params.get("num_snapshots", 48))),
                float(params.get("jitter_scale", 0.01)),
                int(params.get("seed", 42)),
                embed_args=params.get("embed_args") or {},
                generator_kwargs=params.get("generator_kwargs") or {},
            )
            # Older runs may lack the hash suffix; prefer legacy name on disk.
            emb_file = _resolve_emb_file(emb_file)
        except Exception:
            ref_file, emb_file = None, None

    ref_exists = bool(ref_file) and (FIGS_DIR / ref_file).exists()
    emb_exists = bool(emb_file) and (FIGS_DIR / emb_file).exists()
    return {
        "id": rid,
        "short_id": rid[:8] if rid else "",
        "name": run.get("name", ""),
        "state_type": state_type,
        "state_name": state_name,
        "start": start,
        "runtime": _fmt_runtime(float(rtime_s) if rtime_s is not None else None),
        "params": params,
        "ref_file": ref_file,
        "emb_file": emb_file,
        "ref_exists": ref_exists,
        "emb_exists": emb_exists,
        "embedder_short": (params.get("embedder") or "").split(".")[-1],
        "generator_short": _dataset_id(
            params.get("generator_path") or "", params.get("generator_kwargs") or {}
        ),
    }


def _mark_stale_views(views: List[Dict[str, Any]]) -> None:
    """Flag runs whose emb HTML was overwritten by a newer sibling with
    identical params. Prefect returns runs sorted by START_TIME_DESC, so the
    first occurrence of each stem is authoritative; later ones are stale.
    Mutates views in place."""
    seen: set = set()
    for v in views:
        stem = v["emb_file"][:-5] if v.get("emb_file") else None
        if stem and stem in seen:
            v["stale"] = True
        else:
            v["stale"] = False
            if stem:
                seen.add(stem)
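
# Illustrative: with runs ordered newest-first, two finished runs sharing the
# stem make_blobs_PCA_N5000_T48_J1E-2_s42_<8hex> (same params, fig deleted and
# resubmitted) end up with stale=False on the first and stale=True on the rest,
# since only the newest run produced the HTML now on disk.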

def _reducer_choices() -> List[Dict[str, str]]:
    return [
        {"key": k, "label": spec["label"], "blurb": spec["blurb"]}
        for k, spec in available_reducers()
    ]


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
    reducers = _reducer_choices()
    default_reducer = reducers[0]["key"] if reducers else None
    default_spec = REDUCERS.get(default_reducer) if default_reducer else None

    q = request.query_params
    required = _chip_filter_tags(q)
    initial_limit = 50 if required else 10
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=initial_limit, required_tags=required)
        dep_id = await PREFECT.deployment_id(client)
    views = [_run_view(r) for r in runs]
    _mark_stale_views(views)

    # Pre-resolve the two sections' open state from the URL so first paint
    # matches (no flash). Intro defaults closed, picker open.
    intro_open = q.get("intro") == "1"
    picker_open = q.get("picker") != "0"
    # Also pre-resolve the radio-group selections so n/f/j render with the
    # correct `checked` attribute on first paint.
    initial_radios = {
        "n": q.get("n") or "500",
        "f": q.get("f") or "24",
        "j": q.get("j") or "0.005",
    }
    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "reducers": reducers,
            "default_reducer": default_reducer,
            "default_spec": default_spec,
            "runs": views,
            "deployment_id": dep_id,
            "prefect_api": PREFECT_API,
            "intro_open": intro_open,
            "picker_open": picker_open,
            "initial_radios": initial_radios,
        },
    )


@app.get("/data.json")
async def data_json() -> JSONResponse:
    return JSONResponse(_dataset_previews())


@app.get("/reducer-form", response_class=HTMLResponse)
async def reducer_form(request: Request, name: str) -> HTMLResponse:
    spec = REDUCERS.get(name)
    if not spec:
        return HTMLResponse("unknown reducer", status_code=404)
    return templates.TemplateResponse(
        request,
        "_reducer_form.html",
        {"reducer_key": name, "spec": spec},
    )


def _chip_filter_tags(params) -> List[str]:
    """Turn chip-filter query params (?dataset=…&algorithm=…&N=…&T=…&J=…)
    into a Prefect `tags all_` list. Empty / missing values skip the axis."""
    keys = ("dataset", "algorithm", "N", "T", "J")
    tags = []
    for k in keys:
        v = (params.get(k) or "").strip()
        if v:
            tags.append(f"{k}:{v}")
    return tags

@app.get("/runs/axes.json")
async def runs_axes() -> JSONResponse:
    """Distinct chip values across the last N deployment-scoped runs. Lets the
    chip bar show the full universe regardless of the current filter."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=500)
    values: Dict[str, set] = {k: set() for k in ("dataset", "algorithm", "N", "T", "J")}
    for r in runs:
        for tag in r.get("tags") or []:
            if ":" not in tag:
                continue
            k, _, v = tag.partition(":")
            if k in values:
                values[k].add(v)

    # Sort numeric axes numerically.
    def _sort(k, vs):
        if k in ("N", "T", "J"):
            return sorted(vs, key=lambda x: float(x) if x else 0.0)
        return sorted(vs)

    return JSONResponse({k: _sort(k, v) for k, v in values.items()})


@app.get("/runs", response_class=HTMLResponse)
async def runs_partial(request: Request) -> HTMLResponse:
    required = _chip_filter_tags(request.query_params)
    # Server-side tag filter → one narrow query per chip state. When any
    # axis is unfiltered, Prefect returns the K most recent for that slice;
    # when fully filtered, usually a handful of exact matches.
    limit = 50 if required else 10
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=limit, required_tags=required)
    views = [_run_view(r) for r in runs]
    _mark_stale_views(views)
    return templates.TemplateResponse(request, "_runs.html", {"runs": views})


@app.post("/submit", response_class=HTMLResponse)
async def submit(request: Request) -> HTMLResponse:
    form = await request.form()
    data: Dict[str, str] = {k: str(v) for k, v in form.items()}

    reducer = data.get("reducer") or ""
    if reducer not in REDUCERS:
        return HTMLResponse(f"unknown reducer: {reducer}", status_code=400)

    # Dataset came from the picker via dataset_id; fall back to explicit
    # generator_path / generator_kwargs only when dataset_id is absent entirely
    # (API consumers). UI form posts always carry the key, so an empty value
    # means the user hit submit without picking — reject rather than silently
    # defaulting to s_curve.
    if "dataset_id" in data:
        dataset_id = data.get("dataset_id") or ""
        if not dataset_id:
            return HTMLResponse("pick a dataset first (§ 1 above)", status_code=400)
        if dataset_id not in DATASET_META:
            return HTMLResponse(f"unknown dataset: {dataset_id}", status_code=400)
        meta = DATASET_META[dataset_id]
        generator_path = meta["path"]
        generator_kwargs = dict(meta["kwargs"])
    else:
        generator_path = data.get("generator_path") or ""
        if not generator_path:
            return HTMLResponse("missing dataset_id or generator_path", status_code=400)
        raw_kwargs = data.get("generator_kwargs") or ""
        try:
            generator_kwargs = json.loads(raw_kwargs) if raw_kwargs else {}
        except json.JSONDecodeError as e:
            return HTMLResponse(f"bad generator_kwargs JSON: {e}", status_code=400)

    try:
        num_points = int(data.get("num_points", "5000") or 5000)
        num_timesteps = int(data.get("num_timesteps", "48") or 48)
        jitter_scale = float(data.get("jitter_scale", "0.01") or 0.01)
        seed = int(data.get("seed", "42") or 42)
    except ValueError as e:
        return HTMLResponse(f"bad numeric input: {e}", status_code=400)

    embed_args = build_embed_args(reducer, data)

    # Reject submissions whose output path would overwrite an existing fig.
    # Hash now covers both embed_args and generator_kwargs, so swiss_roll vs
    # swiss_roll_hole (and blobs with varying n_features, etc.) no longer
    # share a stem. Also check the legacy hashless path for pre-hash figs.
    _, hashed_emb = synthesize_output_paths(
        generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
        embed_args=embed_args, generator_kwargs=generator_kwargs,
    )
    _, legacy_emb = synthesize_output_paths(
        generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
    )
    for candidate in (hashed_emb, legacy_emb):
        if (FIGS_DIR / candidate).exists():
            return HTMLResponse(
                f"a run with matching params already exists ({candidate}). "
                f"change a param or delete the fig first.",
                status_code=409,
            )
    parameters: Dict[str, Any] = {
        "num_points": num_points,
        "num_timesteps": num_timesteps,
        "jitter_scale": jitter_scale,
        "seed": seed,
        "generator_path": generator_path,
        "embedder": reducer,
        "embed_args": embed_args,
    }
    if generator_kwargs:
        parameters["generator_kwargs"] = generator_kwargs

    tags = build_run_tags(
        generator_path, generator_kwargs, reducer,
        num_points, num_timesteps, jitter_scale,
    )
    async with httpx.AsyncClient(timeout=10.0) as client:
        run = await PREFECT.create_run(client, parameters, tags=tags)

    if not run:
        return HTMLResponse(
            f"could not reach Prefect API at {PREFECT_API}",
            status_code=502,
        )
    if "error" in run:
        return HTMLResponse(
            f"prefect error ({run.get('status')}): {run.get('error')[:500]}",
", status_code=502, ) ref_file, emb_file = synthesize_output_paths( generator_path, reducer, num_points, num_timesteps, jitter_scale, seed, embed_args=embed_args, generator_kwargs=generator_kwargs, ) RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file} # Return freshly refreshed runs partial so htmx can swap the right column async with httpx.AsyncClient(timeout=5.0) as client: runs = await PREFECT.recent_runs(client, limit=10) views = [_run_view(r) for r in runs] _mark_stale_views(views) return templates.TemplateResponse( request, "_runs.html", {"runs": views, "just_submitted": run["id"]}, ) def _scan_metrics() -> List[Dict[str, Any]]: """Read every `*.metrics.json` in FIGS_DIR and return them as a list.""" out: List[Dict[str, Any]] = [] for p in sorted(FIGS_DIR.glob("*.metrics.json"), key=lambda p: p.stat().st_mtime, reverse=True): try: data = json.loads(p.read_text()) except (OSError, json.JSONDecodeError): continue data["filename"] = p.name data["embedding_file"] = p.name.replace(".metrics.json", ".html") out.append(data) return out @app.get("/metrics", response_class=HTMLResponse) async def metrics_page(request: Request) -> HTMLResponse: async with httpx.AsyncClient(timeout=5.0) as client: dep_id = await PREFECT.deployment_id(client) return templates.TemplateResponse( request, "metrics.html", {"prefect_api": PREFECT_API, "deployment_id": dep_id}, ) @app.get("/metrics.json") async def metrics_json() -> JSONResponse: return JSONResponse(_scan_metrics()) _STEM_RE = re.compile( r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.Ee+\-]+_s\d+(?:_[0-9a-f]{8})?$" ) # Map short generator name ("make_blobs") to its DATASET_META entry. # swiss_roll and swiss_roll_hole collide on path; first wins (plain variant). _GEN_TO_META: Dict[str, Dict[str, Any]] = {} for _m in DATASET_META.values(): _GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m) # Kwargs the flow injects / we supply explicitly — never part of the # dataset's semantic identity, so strip them before DATASET_META matching # and before regenerating labels. _TRANSIENT_GEN_KWARGS = {"n_samples", "random_state"} def _clean_gen_kwargs(gk: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: if gk is None: return None return {k: v for k, v in gk.items() if k not in _TRANSIENT_GEN_KWARGS} # Tag axes the chip-filter and backfill both care about. Keep as # (short_prefix, builder) pairs so adding an axis is a one-line change. TAG_AXES = ("dataset", "algorithm", "N", "T", "J") def build_run_tags( generator_path: str, generator_kwargs: Optional[Dict[str, Any]], embedder: str, num_points: int, num_timesteps: int, jitter_scale: float, ) -> List[str]: """Tags written onto every flow run so the chip filter can narrow server-side via Prefect's tag:all_ filter. Single value per axis; the client's cassette chips pick exactly one per filter.""" return [ f"dataset:{_dataset_id(generator_path, generator_kwargs)}", f"algorithm:{(embedder or '').rsplit('.', 1)[-1]}", f"N:{int(num_points)}", f"T:{int(num_timesteps)}", f"J:{sci_notation(jitter_scale)}", ] def _dataset_id(generator_path: str, generator_kwargs: Optional[Dict[str, Any]]) -> str: """Human-scale identifier for a run's dataset — e.g. 'swiss_roll' vs 'swiss_roll_hole' — by matching (path, cleaned kwargs) against DATASET_META. 

def _dataset_id(generator_path: str, generator_kwargs: Optional[Dict[str, Any]]) -> str:
    """Human-scale identifier for a run's dataset — e.g. 'swiss_roll' vs
    'swiss_roll_hole' — by matching (path, cleaned kwargs) against
    DATASET_META. Falls back to the path short-name when no match."""
    gen_short = (generator_path or "").rsplit(".", 1)[-1]
    gk = _clean_gen_kwargs(generator_kwargs)
    candidates = [
        (k, m)
        for k, m in DATASET_META.items()
        if m["path"].rsplit(".", 1)[-1] == gen_short
    ]
    if not candidates:
        return gen_short
    if gk is not None:
        for k, m in candidates:
            if m["kwargs"] == gk:
                return k
    return candidates[0][0]


def _lookup_dataset_meta(
    generator_short: str, generator_kwargs: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Match DATASET_META by generator short-name AND kwargs when available.
    Falls back to first-wins when kwargs are unknown (ambiguous for
    swiss_roll vs swiss_roll_hole — both share `make_swiss_roll`)."""
    candidates = [
        m
        for m in DATASET_META.values()
        if m["path"].rsplit(".", 1)[-1] == generator_short
    ]
    if not candidates:
        return None
    gk = _clean_gen_kwargs(generator_kwargs)
    if gk is not None:
        for m in candidates:
            if m["kwargs"] == gk:
                return m
    return candidates[0]


def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
    """Attach per-point class/continuous labels by regenerating the dataset
    with the same (generator, n_samples, kwargs). random_state is fixed at 0
    (the flow's _DEFAULT_GENERATOR_KWARGS) — the stem's `seed` drives jitter,
    not the generator. Jitter-added points (id >= num_points) get None so the
    client renders them as black.

    Discovers generator_kwargs in priority order: (1) payload meta (sidecar
    runs from the updated flow); (2) sibling metrics.json; (3) DATASET_META
    by first-match (ambiguous for swiss_roll/swiss_roll_hole — need a
    backfilled metrics.json to disambiguate)."""
    meta = d.get("meta") or {}
    gen_short = meta.get("generator") or ""
    gk = meta.get("generator_kwargs")
    if gk is None:
        stem = meta.get("stem")
        if stem:
            mx = FIGS_DIR / f"{stem}.metrics.json"
            if mx.is_file():
                try:
                    gk = json.loads(mx.read_text(encoding="utf-8")).get(
                        "meta", {}
                    ).get("generator_kwargs")
                except Exception:
                    gk = None
    dm = _lookup_dataset_meta(gen_short, gk)
    if not dm:
        return d
    # Replace the stem-derived generator short (ambiguous for swiss_roll vs
    # hole) with the matched DATASET_META id for the panel header.
    for key, entry in DATASET_META.items():
        if entry is dm:
            d["meta"]["generator"] = key
            break
    kwargs_to_use = _clean_gen_kwargs(gk) if gk is not None else dm["kwargs"]
    try:
        mod_path, cls_name = dm["path"].rsplit(".", 1)
        fn = getattr(importlib.import_module(mod_path), cls_name)
        N = int(meta["num_points"])
        _, gen_labels = fn(n_samples=N, random_state=0, **kwargs_to_use)
        out_labels: List[Optional[float]] = []
        for pid in d["point_ids"]:
            if isinstance(pid, int) and 0 <= pid < N:
                v = gen_labels[pid]
                out_labels.append(
                    float(v)
                    if hasattr(v, "item") or isinstance(v, (int, float))
                    else None
                )
            else:
                out_labels.append(None)
        d["labels"] = out_labels
        d["label_kind"] = dm["kind"]
    except Exception:
        pass
    return d


@lru_cache(maxsize=32)
def _cached_frames(stem: str) -> str:
    """Return the frames dict as a JSON string. Prefers a .frames.json sidecar
    (emitted by new flow runs); falls back to parsing .html (for pre-sidecar
    runs). Either way, enriches with dataset labels."""
    sidecar = FIGS_DIR / f"{stem}.frames.json"
    if sidecar.is_file():
        d = json.loads(sidecar.read_text(encoding="utf-8"))
    else:
        html = FIGS_DIR / f"{stem}.html"
        d = parse_plotly_run(html)
    # Override meta.stem with the URL-requested stem — after a backfill the
    # file was renamed but the baked-in meta.stem still points at the old
    # name. Enrichment uses this to find the sibling metrics.json.
    d.setdefault("meta", {})["stem"] = stem
    d = _enrich_with_labels(d)
    return json.dumps(d, separators=(",", ":"))


@app.get("/api/runs/{stem}/frames.json")
async def run_frames(stem: str) -> Response:
    if not _STEM_RE.match(stem):
        raise HTTPException(400, f"malformed stem: {stem!r}")
    if (
        not (FIGS_DIR / f"{stem}.frames.json").is_file()
        and not (FIGS_DIR / f"{stem}.html").is_file()
    ):
        raise HTTPException(404, f"no such run: {stem}")
    try:
        payload = _cached_frames(stem)
    except Exception as e:
        raise HTTPException(500, f"parse failed: {e}")
    return Response(
        content=payload,
        media_type="application/json",
        headers={"Cache-Control": "no-cache"},
    )


@app.get("/compare", response_class=HTMLResponse)
async def compare_page(request: Request) -> HTMLResponse:
    q = request.query_params
    stems = [s for s in q.getlist("stem") if s]
    if not stems:
        # Legacy two-stem form: ?a=&b=
        stems = [s for s in (q.get("a", ""), q.get("b", "")) if s]
    if not (2 <= len(stems) <= 8):
        raise HTTPException(400, f"need 2..8 stems, got {len(stems)}")
    for stem in stems:
        if not _STEM_RE.match(stem):
            raise HTTPException(400, f"malformed stem: {stem!r}")
        has_sidecar = (FIGS_DIR / f"{stem}.frames.json").is_file()
        has_html = (FIGS_DIR / f"{stem}.html").is_file()
        if not (has_sidecar or has_html):
            raise HTTPException(404, f"no such run: {stem}")
    return templates.TemplateResponse(request, "compare.html", {"stems": stems})


@app.get("/health")
async def health() -> JSONResponse:
    async with httpx.AsyncClient(timeout=3.0) as client:
        dep = await PREFECT.deployment_id(client)
    return JSONResponse(
        {"ok": True, "deployment_id": dep, "prefect_api": PREFECT_API}
    )