dr-sandbox/app/web/main.py
Michael Pilosov 9b178dad38 runs: filter chips + compare selection up to 8
- /compare accepts ?stem=…&stem=… (repeated) for 2-8 runs; the legacy
  ?a=&b= form still works. compare.js parses the multi-stem query; the
  template drops the unused stem_a/_b data attrs.
- compare-select.js: MAX bumped to 8; the button enables at 2-8 selected.
  The URL is emitted as ?stem=… per selection.
- The runs list gets a dataset/algorithm chip filter bar above #runs-slot
  (pattern ported from metrics.js). Chips reflect the union of values in
  the current list; selection state persists across htmx swaps.
  Non-matching rows get .filtered-out (display:none).
- _runs.html li now carries data-embedder/data-generator so the filter
  can key on them.
2026-04-22 16:41:06 -06:00

"""
web1 — "Scientific instrument / research notebook"
A FastAPI UI for kicking off the embedding-flow Prefect deployment and
viewing the resulting HTML animations.
Design: restrained, typography-driven, two-column notebook layout. No CSS
framework; hand-written styles.
"""
from __future__ import annotations
import hashlib
import importlib.util
import json
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sklearn.datasets import (
    make_blobs,
    make_classification,
    make_gaussian_quantiles,
    make_s_curve,
    make_swiss_roll,
)

from app.web.plotly_parse import parse_plotly_run
# ---------------------------------------------------------------------------
# Paths / constants
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = BASE_DIR.parent.parent # /home/mm/work/dr-sandbox
FIGS_DIR = PROJECT_ROOT / "figs"
FIGS_DIR.mkdir(parents=True, exist_ok=True)
PREFECT_API = os.environ.get("PREFECT_API_URL", "http://localhost:4200/api")
DEPLOYMENT_NAME = "embedding-flow/embedding-flow"
# ---------------------------------------------------------------------------
# Dataset catalogue
# ---------------------------------------------------------------------------
# Metadata for the /data.json endpoint consumed by the dataset picker, and
# for server-side lookup when the picker posts its selection back. kwargs
# must carry n_features=3 for generators that aren't already 3-D, since
# they'll be forwarded verbatim to the Prefect flow's generator_kwargs.
DATASET_PREVIEW_N = 5000
DATASET_PREVIEW_SEED = 0
DATASET_META: Dict[str, Dict[str, Any]] = {
    "s_curve": {
        "name": "S-Curve",
        "path": "sklearn.datasets.make_s_curve",
        "kwargs": {},
        "description": (
            "A 2-D manifold warped into R³. Continuous label encodes position "
            "along the curve — a good test of whether a reducer unrolls the "
            "sheet without tearing."
        ),
        "kind": "continuous",
    },
    "swiss_roll": {
        "name": "Swiss Roll",
        "path": "sklearn.datasets.make_swiss_roll",
        "kwargs": {},
        "description": (
            "A rolled-up plane. The canonical hard case for linear methods: "
            "PCA collapses the spiral, non-linear methods should recover the "
            "unroll."
        ),
        "kind": "continuous",
    },
    "swiss_roll_hole": {
        "name": "Swiss Roll (hole)",
        "path": "sklearn.datasets.make_swiss_roll",
        "kwargs": {"hole": True},
        "description": (
            "Swiss roll with a rectangular hole punched through. Same manifold, "
            "non-trivial topology — a faithful unroll should preserve the hole "
            "rather than smearing it closed."
        ),
        "kind": "continuous",
    },
    "blobs": {
        "name": "Gaussian Blobs",
        "path": "sklearn.datasets.make_blobs",
        "kwargs": {"n_features": 3, "centers": 5, "cluster_std": 1.0},
        "description": (
            "Five isotropic Gaussian clusters in R³. Discrete class labels. "
            "Tests whether a reducer preserves cluster separation when "
            "projected to 2-D."
        ),
        "kind": "categorical",
    },
    "gaussian_quantiles": {
        "name": "Gaussian Quantiles",
        "path": "sklearn.datasets.make_gaussian_quantiles",
        "kwargs": {"n_features": 3, "n_classes": 4},
        "description": (
            "Concentric Gaussian shells in R³; class = which shell. Classes "
            "are linearly inseparable by construction — PCA collapses them, "
            "kernel and manifold methods have a chance."
        ),
        "kind": "categorical",
    },
    "classification": {
        "name": "Hypercube Clusters",
        "path": "sklearn.datasets.make_classification",
        "kwargs": {
            "n_features": 3,
            "n_informative": 3,
            "n_redundant": 0,
            "n_repeated": 0,
            "n_classes": 4,
            "n_clusters_per_class": 2,
            "class_sep": 1.5,
        },
        "description": (
            "Four classes, two sub-clusters each, placed at hypercube vertices "
            "with informative noise. A denser discrete test than blobs — "
            "within-class bimodality stresses cluster-preserving reducers."
        ),
        "kind": "categorical",
    },
}
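# Illustrative trace of the kwargs forwarding described above (dataset names
# and kwargs come from the catalogue; the flow side is assumed): picking
# "swiss_roll_hole" makes /submit forward
#   generator_path="sklearn.datasets.make_swiss_roll"
#   generator_kwargs={"hole": True}
# verbatim in the Prefect run parameters, while "blobs" carries
#   {"n_features": 3, "centers": 5, "cluster_std": 1.0}.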
@lru_cache(maxsize=1)
def _dataset_previews() -> Dict[str, Dict[str, Any]]:
    """Attach freshly-generated points+labels to the catalogue for the picker."""
    N, SEED = DATASET_PREVIEW_N, DATASET_PREVIEW_SEED
    s, sl = make_s_curve(n_samples=N, noise=0.03, random_state=SEED)
    sr, srl = make_swiss_roll(n_samples=N, noise=0.15, random_state=SEED)
    srh, srhl = make_swiss_roll(n_samples=N, noise=0.15, hole=True, random_state=SEED)
    b, bl = make_blobs(
        n_samples=N, n_features=3, centers=5, cluster_std=1.0, random_state=SEED
    )
    gq, gql = make_gaussian_quantiles(
        n_samples=N, n_features=3, n_classes=4, random_state=SEED
    )
    cls, clsl = make_classification(
        n_samples=N,
        n_features=3,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        n_classes=4,
        n_clusters_per_class=2,
        class_sep=1.5,
        random_state=SEED,
    )
    samples = {
        "s_curve": (s, sl),
        "swiss_roll": (sr, srl),
        "swiss_roll_hole": (srh, srhl),
        "blobs": (b, bl),
        "gaussian_quantiles": (gq, gql),
        "classification": (cls, clsl),
    }
    out: Dict[str, Dict[str, Any]] = {}
    for key, meta in DATASET_META.items():
        pts, labels = samples[key]
        out[key] = {
            **meta,
            "points": pts.tolist(),
            "labels": labels.tolist(),
        }
    return out
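# Sketch of the payload /data.json serves the picker (points/labels truncated
# here for illustration; the real lists hold DATASET_PREVIEW_N entries each):
#   {"s_curve": {"name": "S-Curve", "kind": "continuous", ...,
#                "points": [[x0, y0, z0], ...], "labels": [t0, ...]},
#    "blobs": {..., "kind": "categorical", "labels": [0, 3, 1, ...]}, ...}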
# ---------------------------------------------------------------------------
# Reducer catalogue
# ---------------------------------------------------------------------------
# Each field tuple: (name, kind, default, choices_or_none, help_or_none)
# kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null"
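# For example, the t-SNE entry below declares
#   ("perplexity", "float", 30.0, None, None)
# i.e. name="perplexity", kind="float", default=30.0, free-form (no choices
# list), no help text; whereas ("init", "str", "pca", ["pca", "random"], None)
# renders as a two-option select.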
REDUCERS: Dict[str, Dict[str, Any]] = {
    "sklearn.decomposition.PCA": {
        "pkg": "sklearn",
        "label": "PCA",
        "blurb": "Principal component analysis. Linear, fast, deterministic.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
        ],
        "advanced": [
            ("svd_solver", "str", "auto", ["auto", "full", "arpack", "randomized"], None),
            ("random_state", "int", 42, None, None),
            ("whiten", "bool", False, None, None),
        ],
    },
    "sklearn.decomposition.FactorAnalysis": {
        "pkg": "sklearn",
        "label": "FactorAnalysis",
        "blurb": "Gaussian latent-factor model with per-feature noise.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("tol", "float", 0.01, None, None),
            ("max_iter", "int", 1000, None, None),
            ("rotation", "str", "", ["", "varimax", "quartimax"], "Empty = None."),
        ],
    },
    "sklearn.decomposition.KernelPCA": {
        "pkg": "sklearn",
        "label": "KernelPCA",
        "blurb": "Non-linear PCA via the kernel trick. Deterministic; kernel choice shapes the output.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("kernel", "str", "rbf", ["linear", "poly", "rbf", "sigmoid", "cosine"], None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("gamma", "str_or_float", "", None, "Empty = 1/n_features."),
            ("degree", "int", 3, None, None),
            ("coef0", "float", 1.0, None, None),
            ("alpha", "float", 1.0, None, None),
        ],
    },
    "sklearn.manifold.Isomap": {
        "pkg": "sklearn",
        "label": "Isomap",
        "blurb": "Geodesic-distance manifold learning via shortest paths on a k-NN graph.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 5, None, None),
        ],
        "advanced": [
            ("metric", "str", "minkowski", None, None),
            ("p", "int", 2, None, "Minkowski power (1 = manhattan, 2 = euclidean)."),
            ("path_method", "str", "auto", ["auto", "FW", "D"], "Floyd-Warshall / Dijkstra / auto."),
            ("neighbors_algorithm", "str", "auto", ["auto", "ball_tree", "kd_tree", "brute"], None),
        ],
    },
    "sklearn.manifold.MDS": {
        "pkg": "sklearn",
        "label": "MDS",
        "blurb": "Multidimensional scaling. Preserves pairwise distances; O(n²) memory.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_init", "int", 4, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("max_iter", "int", 300, None, None),
            ("metric_mds", "bool", True, None, "Metric (True) vs non-metric MDS."),
            ("metric", "str", "euclidean", None, None),
            ("eps", "float", 1e-6, None, "Convergence tolerance."),
        ],
    },
    "sklearn.manifold.SpectralEmbedding": {
        "pkg": "sklearn",
        "label": "SpectralEmbedding",
        "blurb": "Laplacian eigenmaps on an affinity graph. What UMAP uses for initialisation.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("affinity", "str", "nearest_neighbors", ["nearest_neighbors", "rbf"], None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("n_neighbors", "int_or_null", "", None, "For affinity=nearest_neighbors. Empty = n/10."),
            ("gamma", "str_or_float", "", None, "For affinity=rbf. Empty = 1/n_features."),
        ],
    },
    "sklearn.manifold.TSNE": {
        "pkg": "sklearn",
        "label": "t-SNE",
        "blurb": "Stochastic neighbour embedding. Local structure preserved.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("perplexity", "float", 30.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("learning_rate", "str_or_float", "auto", None, "'auto' or a float."),
            ("n_iter", "int", 1000, None, None),
            ("metric", "str", "euclidean", None, None),
            ("early_exaggeration", "float", 12.0, None, None),
            ("init", "str", "pca", ["pca", "random"], None),
        ],
    },
    "umap.UMAP": {
        "pkg": "umap",
        "label": "UMAP",
        "blurb": "Uniform manifold approximation. Preserves local + some global structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 15, None, None),
            ("min_dist", "float", 0.1, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("metric", "str", "euclidean", None, None),
            ("n_epochs", "int_or_null", "", None, "Empty = None (auto)."),
            ("spread", "float", 1.0, None, None),
            ("init", "str", "spectral", ["spectral", "random"], None),
        ],
    },
    "pacmap.PaCMAP": {
        "pkg": "pacmap",
        "label": "PaCMAP",
        "blurb": "Pairwise-controlled manifold approximation. Balanced local/global.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
        ],
    },
    "pacmap.LocalMAP": {
        "pkg": "pacmap",
        "label": "LocalMAP",
        "blurb": "PaCMAP variant with a low-distance threshold; sharper local structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
            ("low_dist_thres", "float", 10.0, None, None),
        ],
    },
    "trimap.TRIMAP": {
        "pkg": "trimap",
        "label": "TriMap",
        "blurb": "Triplet-based dimensionality reduction. Emphasises global structure.",
        "key": [
            ("n_dims", "int", 2, None, "Locked."),
            ("n_inliers", "int", 10, None, None),
            ("n_outliers", "int", 5, None, None),
            ("n_random", "int", 5, None, None),
        ],
        "advanced": [
            ("lr", "float", 0.1, None, None),
            ("n_iters", "int", 400, None, None),
            ("weight_adj", "float", 500.0, None, None),
        ],
    },
    "sklearn.random_projection.GaussianRandomProjection": {
        "pkg": "sklearn",
        "label": "GaussianRandomProjection",
        "blurb": "Johnson-Lindenstrauss baseline. Cheap, distance-preserving in expectation, structure-agnostic.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [],
    },
}
def available_reducers() -> List[Tuple[str, Dict[str, Any]]]:
    out = []
    for key, spec in REDUCERS.items():
        if importlib.util.find_spec(spec["pkg"]) is not None:
            out.append((key, spec))
    return out
# ---------------------------------------------------------------------------
# Parameter coercion
# ---------------------------------------------------------------------------
def _coerce(kind: str, raw: Optional[str], default: Any) -> Any:
    if raw is None:
        return default
    s = raw.strip() if isinstance(raw, str) else raw
    if kind == "int":
        if s == "" or s is None:
            return default
        return int(s)
    if kind == "float":
        if s == "" or s is None:
            return default
        return float(s)
    if kind == "bool":
        # Checkbox: "on" / absent
        return bool(s) and s not in ("0", "false", "False", "")
    if kind == "str":
        if s == "":
            return None if default in (None, "") else default if default else ""
        return s
    if kind == "str_or_float":
        if s == "":
            return default
        try:
            return float(s)
        except (ValueError, TypeError):
            return s
    if kind == "int_or_null":
        if s == "":
            return None
        return int(s)
    return s
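# A few hedged examples of the coercions above (worked out from the branches
# by hand, not from a test suite):
#   _coerce("int", "25", 10)            -> 25      _coerce("int", "", 10) -> 10
#   _coerce("str_or_float", "0.5", "")  -> 0.5     (parses as float)
#   _coerce("str_or_float", "auto", "") -> "auto"  (float() fails, kept as str)
#   _coerce("int_or_null", "", 5)       -> None    (explicit null, not default)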
def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
    spec = REDUCERS[reducer_key]
    out: Dict[str, Any] = {}
    all_fields = list(spec["key"]) + list(spec["advanced"])
    for (name, kind, default, _choices, _help) in all_fields:
        raw = form.get(f"embed__{name}")
        if kind == "bool":
            raw_v = "on" if f"embed__{name}" in form else ""
            value = bool(raw_v)
        else:
            value = _coerce(kind, raw, default)
        # Null-stripping: drop empty rotations etc.
        if value is None:
            continue
        if isinstance(value, str) and value == "" and default in (None, ""):
            continue
        out[name] = value
    # Always force n_components / n_dims to 2 (flow assertion)
    if "n_components" in out:
        out["n_components"] = 2
    if "n_dims" in out:
        out["n_dims"] = 2
    return out
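# Hedged example of the assembly above, using the TSNE spec (form keys carry
# the UI's embed__ prefix; absent fields fall back to spec defaults):
#   build_embed_args("sklearn.manifold.TSNE", {"embed__perplexity": "45"})
#   -> {"n_components": 2, "perplexity": 45.0, "random_state": 42,
#       "learning_rate": "auto", "n_iter": 1000, "metric": "euclidean",
#       "early_exaggeration": 12.0, "init": "pca"}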
# ---------------------------------------------------------------------------
# Output-path synthesis (mirrors flows/embedding_flow.py lines ~162-168)
# ---------------------------------------------------------------------------
def run_args_hash(
    embed_args: Optional[Dict[str, Any]],
    generator_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """8-hex digest of (embed_args, generator_kwargs). When generator_kwargs
    is empty/None we hash embed_args alone — preserves stems for the plain
    generators (s_curve, plain swiss_roll) that never had gen_kwargs. For
    kwargs-bearing variants (swiss_roll_hole, blobs, gaussian_quantiles,
    classification), the hash now disambiguates them from their kwargs-less
    siblings — run scripts/backfill_hashes.py to rehash existing figs."""
    if generator_kwargs:
        payload: Any = {
            "embed_args": embed_args or {},
            "generator_kwargs": generator_kwargs,
        }
    else:
        payload = embed_args or {}
    s = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha1(s.encode()).hexdigest()[:8]
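# Hedged illustration of the branching above (digests shown are placeholders,
# not real values):
#   run_args_hash({"n_components": 2})                  -> some 8-hex digest
#   run_args_hash({"n_components": 2}, {})              -> identical digest
#                                                          (empty-kwargs branch)
#   run_args_hash({"n_components": 2}, {"hole": True})  -> a different digest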
# Back-compat alias — some call sites passed only embed_args.
embed_args_hash = run_args_hash
def synthesize_output_paths(
    generator_path: str,
    embedder: str,
    num_points: int,
    num_timesteps: int,
    jitter_scale: float,
    seed: int,
    embed_args: Optional[Dict[str, Any]] = None,
    generator_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, str]:
    gen = generator_path.split(".")[-1]
    emb = embedder.split(".")[-1]
    ref = f"{gen}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    base = f"{gen}_{emb}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}"
    if embed_args is None:
        embf = f"{base}.html"
    else:
        embf = f"{base}_{run_args_hash(embed_args, generator_kwargs)}.html"
    return ref, embf
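# Hedged example of the stems produced above (hash shown as a placeholder):
#   synthesize_output_paths("sklearn.datasets.make_swiss_roll",
#                           "sklearn.decomposition.PCA", 5000, 48, 0.01, 42,
#                           embed_args={"n_components": 2})
#   -> ("make_swiss_roll_Reference_N5000_T48_J0.01_s42.html",
#       "make_swiss_roll_PCA_N5000_T48_J0.01_s42_<8hex>.html")
# With embed_args=None the embedded name drops the hash suffix entirely.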
def _resolve_emb_file(synthesized: str) -> str:
    """Disk-backed fallback: prefer the synthesized (hashed) name; if that
    doesn't exist on disk but an older hash-less variant does, return that
    so pre-hash runs still render in the UI."""
    if (FIGS_DIR / synthesized).exists():
        return synthesized
    # Strip trailing _<8hex>.html to get the legacy name.
    m = re.match(r"^(?P<base>.+)_[0-9a-f]{8}\.html$", synthesized)
    if m:
        legacy = m.group("base") + ".html"
        if (FIGS_DIR / legacy).exists():
            return legacy
    return synthesized  # new / still-running run; let emb_exists resolve
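# Hedged example of the fallback above ("1a2b3c4d" is a placeholder hash):
# given "make_s_curve_PCA_N5000_T48_J0.01_s42_1a2b3c4d.html" missing on disk
# while "make_s_curve_PCA_N5000_T48_J0.01_s42.html" exists, the legacy name
# is returned so a pre-hash run still renders.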
# ---------------------------------------------------------------------------
# Prefect client
# ---------------------------------------------------------------------------
class Prefect:
    def __init__(self, base: str = PREFECT_API) -> None:
        self.base = base.rstrip("/")
        self._deployment_id: Optional[str] = None

    async def deployment_id(self, client: httpx.AsyncClient) -> Optional[str]:
        if self._deployment_id:
            return self._deployment_id
        try:
            r = await client.get(f"{self.base}/deployments/name/{DEPLOYMENT_NAME}")
            if r.status_code == 200:
                self._deployment_id = r.json()["id"]
                return self._deployment_id
        except httpx.HTTPError:
            return None
        return None

    async def create_run(
        self, client: httpx.AsyncClient, parameters: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        dep = await self.deployment_id(client)
        if not dep:
            return None
        r = await client.post(
            f"{self.base}/deployments/{dep}/create_flow_run",
            json={"parameters": parameters},
        )
        if r.status_code >= 400:
            return {"error": r.text, "status": r.status_code}
        return r.json()

    async def recent_runs(
        self, client: httpx.AsyncClient, limit: int = 10
    ) -> List[Dict[str, Any]]:
        dep = await self.deployment_id(client)
        if not dep:
            return []
        try:
            r = await client.post(
                f"{self.base}/flow_runs/filter",
                json={
                    "sort": "START_TIME_DESC",
                    "limit": limit,
                    "flow_runs": {"deployment_id": {"any_": [dep]}},
                },
            )
            if r.status_code == 200:
                return r.json()
        except httpx.HTTPError:
            return []
        return []
PREFECT = Prefect()
# ---------------------------------------------------------------------------
# In-memory mapping: flow_run_id -> synthesized output file names
# (best-effort; lost on restart, which is fine)
# ---------------------------------------------------------------------------
RUN_OUTPUTS: Dict[str, Dict[str, str]] = {}
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(title="web1 — embedding notebook", docs_url=None, redoc_url=None)
app.mount("/figs", StaticFiles(directory=str(FIGS_DIR)), name="figs")
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
def _fmt_runtime(seconds: Optional[float]) -> Optional[str]:
    if seconds is None or seconds <= 0:
        return None
    if seconds < 60:
        return f"{seconds:.1f}s"
    m, s = divmod(int(seconds), 60)
    if m < 60:
        return f"{m}m{s:02d}s"
    h, m = divmod(m, 60)
    return f"{h}h{m:02d}m"
def _run_view(run: Dict[str, Any]) -> Dict[str, Any]:
    """Normalise a flow-run dict for the template."""
    rid = run.get("id", "")
    state_type = (run.get("state_type") or "PENDING").upper()
    state_name = run.get("state_name") or state_type.title()
    # Prefect labels a scheduled run that's past its start time "Late". The
    # runner is simply busy / awaiting its turn at the concurrency cap — show
    # "Queued" instead, which is the honest description of the state.
    if state_name == "Late":
        state_name = "Queued"
    start = run.get("start_time") or run.get("expected_start_time") or run.get("created")
    params = run.get("parameters") or {}
    # estimated_run_time recomputes server-side (ticks up while RUNNING);
    # total_run_time is the authoritative value once the run finishes.
    rtime_s = run.get("estimated_run_time") or run.get("total_run_time")
    # Try to look up synthesised outputs either from memory or from params
    ref_file = None
    emb_file = None
    outs = RUN_OUTPUTS.get(rid)
    if outs:
        ref_file = outs["ref"]
        emb_file = outs["embed"]
    elif params:
        try:
            ref_file, emb_file = synthesize_output_paths(
                params.get("generator_path", "sklearn.datasets.make_s_curve"),
                params.get("embedder", "sklearn.decomposition.FactorAnalysis"),
                int(params.get("num_points", 5000)),
                # Fallback to the old num_snapshots key for runs dispatched
                # before the T-rename, so historical figs still resolve after
                # `rename 's/_S/_T/' figs/*.html`.
                int(params.get("num_timesteps", params.get("num_snapshots", 48))),
                float(params.get("jitter_scale", 0.01)),
                int(params.get("seed", 42)),
                embed_args=params.get("embed_args") or {},
                generator_kwargs=params.get("generator_kwargs") or {},
            )
            # Older runs may lack the hash suffix; prefer legacy name on disk.
            emb_file = _resolve_emb_file(emb_file)
        except Exception:
            ref_file, emb_file = None, None
    ref_exists = bool(ref_file) and (FIGS_DIR / ref_file).exists()
    emb_exists = bool(emb_file) and (FIGS_DIR / emb_file).exists()
    return {
        "id": rid,
        "short_id": rid[:8] if rid else "",
        "name": run.get("name", ""),
        "state_type": state_type,
        "state_name": state_name,
        "start": start,
        "runtime": _fmt_runtime(float(rtime_s) if rtime_s is not None else None),
        "params": params,
        "ref_file": ref_file,
        "emb_file": emb_file,
        "ref_exists": ref_exists,
        "emb_exists": emb_exists,
        "embedder_short": (params.get("embedder") or "").split(".")[-1],
        "generator_short": (params.get("generator_path") or "").split(".")[-1],
    }
def _mark_stale_views(views: List[Dict[str, Any]]) -> None:
    """Flag runs whose emb HTML was overwritten by a newer sibling with
    identical params. Prefect returns runs sorted by START_TIME_DESC, so the
    first occurrence of each stem is authoritative; later ones are stale.
    Mutates views in place."""
    seen: set = set()
    for v in views:
        stem = v["emb_file"][:-5] if v.get("emb_file") else None
        if stem and stem in seen:
            v["stale"] = True
        else:
            v["stale"] = False
        if stem:
            seen.add(stem)
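# Illustrative ordering contract (stem shortened, hash is a placeholder):
# views arrive newest-first, so when two runs both resolve to
# "make_blobs_PCA_..._1a2b3c4d.html", the newer one keeps stale=False and
# the older duplicate (whose HTML was overwritten) is flagged stale=True.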
def _reducer_choices() -> List[Dict[str, str]]:
    return [
        {"key": k, "label": spec["label"], "blurb": spec["blurb"]}
        for k, spec in available_reducers()
    ]
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
reducers = _reducer_choices()
default_reducer = reducers[0]["key"] if reducers else None
default_spec = REDUCERS.get(default_reducer) if default_reducer else None
async with httpx.AsyncClient(timeout=5.0) as client:
runs = await PREFECT.recent_runs(client, limit=10)
dep_id = await PREFECT.deployment_id(client)
views = [_run_view(r) for r in runs]
_mark_stale_views(views)
return templates.TemplateResponse(
request,
"index.html",
{
"reducers": reducers,
"default_reducer": default_reducer,
"default_spec": default_spec,
"runs": views,
"deployment_id": dep_id,
"prefect_api": PREFECT_API,
},
)
@app.get("/data.json")
async def data_json() -> JSONResponse:
return JSONResponse(_dataset_previews())
@app.get("/reducer-form", response_class=HTMLResponse)
async def reducer_form(request: Request, name: str) -> HTMLResponse:
spec = REDUCERS.get(name)
if not spec:
return HTMLResponse("<p class='err'>unknown reducer</p>", status_code=404)
return templates.TemplateResponse(
request,
"_reducer_form.html",
{"reducer_key": name, "spec": spec},
)
@app.get("/runs", response_class=HTMLResponse)
async def runs_partial(request: Request) -> HTMLResponse:
async with httpx.AsyncClient(timeout=5.0) as client:
runs = await PREFECT.recent_runs(client, limit=10)
views = [_run_view(r) for r in runs]
_mark_stale_views(views)
return templates.TemplateResponse(
request, "_runs.html", {"runs": views}
)
@app.post("/submit", response_class=HTMLResponse)
async def submit(request: Request) -> HTMLResponse:
form = await request.form()
data: Dict[str, str] = {k: str(v) for k, v in form.items()}
reducer = data.get("reducer") or ""
if reducer not in REDUCERS:
return HTMLResponse(
f"<div class='flash err'>unknown reducer: {reducer}</div>",
status_code=400,
)
# Dataset came from the picker via dataset_id; fall back to explicit
# generator_path / generator_kwargs only when dataset_id is absent entirely
# (API consumers). UI form posts always carry the key, so an empty value
# means the user hit submit without picking — reject rather than silently
# defaulting to s_curve.
if "dataset_id" in data:
dataset_id = data.get("dataset_id") or ""
if not dataset_id:
return HTMLResponse(
"<div class='flash err'>pick a dataset first (§ 1 above)</div>",
status_code=400,
)
if dataset_id not in DATASET_META:
return HTMLResponse(
f"<div class='flash err'>unknown dataset: {dataset_id}</div>",
status_code=400,
)
meta = DATASET_META[dataset_id]
generator_path = meta["path"]
generator_kwargs = dict(meta["kwargs"])
else:
generator_path = data.get("generator_path") or ""
if not generator_path:
return HTMLResponse(
"<div class='flash err'>missing dataset_id or generator_path</div>",
status_code=400,
)
raw_kwargs = data.get("generator_kwargs") or ""
try:
generator_kwargs = json.loads(raw_kwargs) if raw_kwargs else {}
except json.JSONDecodeError as e:
return HTMLResponse(
f"<div class='flash err'>bad generator_kwargs JSON: {e}</div>",
status_code=400,
)
try:
num_points = int(data.get("num_points", "5000") or 5000)
num_timesteps = int(data.get("num_timesteps", "48") or 48)
jitter_scale = float(data.get("jitter_scale", "0.01") or 0.01)
seed = int(data.get("seed", "42") or 42)
except ValueError as e:
return HTMLResponse(
f"<div class='flash err'>bad numeric input: {e}</div>", status_code=400
)
embed_args = build_embed_args(reducer, data)
# Reject submissions whose output path would overwrite an existing fig.
# Hash now covers both embed_args and generator_kwargs, so swiss_roll vs
# swiss_roll_hole (and blobs with varying n_features, etc.) no longer
# share a stem. Also check the legacy hashless path for pre-hash figs.
_, hashed_emb = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
embed_args=embed_args, generator_kwargs=generator_kwargs,
)
_, legacy_emb = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
)
for candidate in (hashed_emb, legacy_emb):
if (FIGS_DIR / candidate).exists():
return HTMLResponse(
f"<div class='flash err'>a run with matching params already "
f"exists (<code>{candidate}</code>). change a param or delete "
f"the fig first.</div>",
status_code=409,
)
parameters: Dict[str, Any] = {
"num_points": num_points,
"num_timesteps": num_timesteps,
"jitter_scale": jitter_scale,
"seed": seed,
"generator_path": generator_path,
"embedder": reducer,
"embed_args": embed_args,
}
if generator_kwargs:
parameters["generator_kwargs"] = generator_kwargs
async with httpx.AsyncClient(timeout=10.0) as client:
run = await PREFECT.create_run(client, parameters)
if not run:
return HTMLResponse(
"<div class='flash err'>could not reach Prefect API at "
f"{PREFECT_API}</div>",
status_code=502,
)
if "error" in run:
return HTMLResponse(
f"<div class='flash err'>prefect error ({run.get('status')}): "
f"<code>{run.get('error')[:500]}</code></div>",
status_code=502,
)
ref_file, emb_file = synthesize_output_paths(
generator_path, reducer, num_points, num_timesteps, jitter_scale, seed,
embed_args=embed_args, generator_kwargs=generator_kwargs,
)
RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file}
# Return freshly refreshed runs partial so htmx can swap the right column
async with httpx.AsyncClient(timeout=5.0) as client:
runs = await PREFECT.recent_runs(client, limit=10)
views = [_run_view(r) for r in runs]
_mark_stale_views(views)
return templates.TemplateResponse(
request,
"_runs.html",
{"runs": views, "just_submitted": run["id"]},
)
def _scan_metrics() -> List[Dict[str, Any]]:
    """Read every `*.metrics.json` in FIGS_DIR and return them as a list."""
    out: List[Dict[str, Any]] = []
    for p in sorted(FIGS_DIR.glob("*.metrics.json"), key=lambda p: p.stat().st_mtime, reverse=True):
        try:
            data = json.loads(p.read_text())
        except (OSError, json.JSONDecodeError):
            continue
        data["filename"] = p.name
        data["embedding_file"] = p.name.replace(".metrics.json", ".html")
        out.append(data)
    return out
@app.get("/metrics", response_class=HTMLResponse)
async def metrics_page(request: Request) -> HTMLResponse:
async with httpx.AsyncClient(timeout=5.0) as client:
dep_id = await PREFECT.deployment_id(client)
return templates.TemplateResponse(
request,
"metrics.html",
{"prefect_api": PREFECT_API, "deployment_id": dep_id},
)
@app.get("/metrics.json")
async def metrics_json() -> JSONResponse:
return JSONResponse(_scan_metrics())
_STEM_RE = re.compile(
    r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+(?:_[0-9a-f]{8})?$"
)
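# Matches the stems synthesize_output_paths() emits, with or without the
# 8-hex args hash (the second stem below uses a placeholder hash):
#   "make_s_curve_PCA_N5000_T48_J0.01_s42"              -> ok
#   "make_swiss_roll_TSNE_N5000_T48_J0.01_s42_1a2b3c4d" -> ok
#   "../../etc/passwd"                                  -> rejected (400)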
# Map short generator name ("make_blobs") to its DATASET_META entry.
# swiss_roll and swiss_roll_hole collide on path; first wins (plain variant).
_GEN_TO_META: Dict[str, Dict[str, Any]] = {}
for _m in DATASET_META.values():
    _GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m)
def _lookup_dataset_meta(
    generator_short: str, generator_kwargs: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Match DATASET_META by generator short-name AND kwargs when available.
    Falls back to first-wins when kwargs are unknown (ambiguous for
    swiss_roll vs swiss_roll_hole — both share `make_swiss_roll`)."""
    candidates = [
        m for m in DATASET_META.values()
        if m["path"].rsplit(".", 1)[-1] == generator_short
    ]
    if not candidates:
        return None
    if generator_kwargs is not None:
        for m in candidates:
            if m["kwargs"] == generator_kwargs:
                return m
    return candidates[0]
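# Hedged examples of the lookup above ("make_moons" is hypothetical, not in
# the catalogue):
#   _lookup_dataset_meta("make_swiss_roll", {"hole": True}) -> swiss_roll_hole
#   _lookup_dataset_meta("make_swiss_roll", None)           -> plain swiss_roll
#                                                              (first-wins)
#   _lookup_dataset_meta("make_moons", None)                -> None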
def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
    """Attach per-point class/continuous labels by regenerating the dataset
    with the same (generator, n_samples, kwargs). random_state is fixed at 0
    (the flow's _DEFAULT_GENERATOR_KWARGS) — the stem's `seed` drives jitter,
    not the generator. Jitter-added points (id >= num_points) get None so
    the client renders them as black.
    Discovers generator_kwargs in priority order: (1) payload meta (sidecar
    runs from the updated flow); (2) sibling metrics.json; (3) DATASET_META
    by first-match (ambiguous for swiss_roll/swiss_roll_hole — need a
    backfilled metrics.json to disambiguate)."""
    meta = d.get("meta") or {}
    gen_short = meta.get("generator") or ""
    gk = meta.get("generator_kwargs")
    if gk is None:
        stem = meta.get("stem")
        if stem:
            mx = FIGS_DIR / f"{stem}.metrics.json"
            if mx.is_file():
                try:
                    gk = json.loads(mx.read_text(encoding="utf-8")).get(
                        "meta", {}
                    ).get("generator_kwargs")
                except Exception:
                    gk = None
    dm = _lookup_dataset_meta(gen_short, gk)
    if not dm:
        return d
    kwargs_to_use = gk if gk is not None else dm["kwargs"]
    try:
        mod_path, cls_name = dm["path"].rsplit(".", 1)
        fn = getattr(importlib.import_module(mod_path), cls_name)
        N = int(meta["num_points"])
        _, gen_labels = fn(n_samples=N, random_state=0, **kwargs_to_use)
        out_labels: List[Optional[float]] = []
        for pid in d["point_ids"]:
            if isinstance(pid, int) and 0 <= pid < N:
                v = gen_labels[pid]
                out_labels.append(
                    float(v) if hasattr(v, "item") or isinstance(v, (int, float)) else None
                )
            else:
                out_labels.append(None)
        d["labels"] = out_labels
        d["label_kind"] = dm["kind"]
    except Exception:
        pass
    return d
@lru_cache(maxsize=32)
def _cached_frames(stem: str) -> str:
    """Return the frames dict as a JSON string. Prefers a <stem>.frames.json
    sidecar (emitted by new flow runs); falls back to parsing <stem>.html
    (for pre-sidecar runs). Either way, enriches with dataset labels."""
    sidecar = FIGS_DIR / f"{stem}.frames.json"
    if sidecar.is_file():
        d = json.loads(sidecar.read_text(encoding="utf-8"))
    else:
        html = FIGS_DIR / f"{stem}.html"
        d = parse_plotly_run(html)
    # Override meta.stem with the URL-requested stem — after a backfill the
    # file was renamed but the baked-in meta.stem still points at the old
    # name. Enrichment uses this to find the sibling metrics.json.
    d.setdefault("meta", {})["stem"] = stem
    d = _enrich_with_labels(d)
    return json.dumps(d, separators=(",", ":"))
@app.get("/api/runs/{stem}/frames.json")
async def run_frames(stem: str) -> Response:
if not _STEM_RE.match(stem):
raise HTTPException(400, f"malformed stem: {stem!r}")
if not (FIGS_DIR / f"{stem}.frames.json").is_file() and not (FIGS_DIR / f"{stem}.html").is_file():
raise HTTPException(404, f"no such run: {stem}")
try:
payload = _cached_frames(stem)
except Exception as e:
raise HTTPException(500, f"parse failed: {e}")
return Response(
content=payload,
media_type="application/json",
headers={"Cache-Control": "no-cache"},
)
@app.get("/compare", response_class=HTMLResponse)
async def compare_page(request: Request) -> HTMLResponse:
q = request.query_params
stems = [s for s in q.getlist("stem") if s]
if not stems:
# Legacy two-stem form: ?a=&b=
stems = [s for s in (q.get("a", ""), q.get("b", "")) if s]
if not (2 <= len(stems) <= 8):
raise HTTPException(400, f"need 2..8 stems, got {len(stems)}")
for stem in stems:
if not _STEM_RE.match(stem):
raise HTTPException(400, f"malformed stem: {stem!r}")
has_sidecar = (FIGS_DIR / f"{stem}.frames.json").is_file()
has_html = (FIGS_DIR / f"{stem}.html").is_file()
if not (has_sidecar or has_html):
raise HTTPException(404, f"no such run: {stem}")
return templates.TemplateResponse(
request, "compare.html", {"stems": stems}
)
@app.get("/health")
async def health() -> JSONResponse:
async with httpx.AsyncClient(timeout=3.0) as client:
dep = await PREFECT.deployment_id(client)
return JSONResponse(
{"ok": True, "deployment_id": dep, "prefect_api": PREFECT_API}
)