""" web1 — "Scientific instrument / research notebook" A FastAPI UI for kicking off the embedding-flow Prefect deployment and viewing the resulting HTML animations. Design: restrained, typography-driven, two-column notebook layout. No CSS framework; hand-written styles. """ from __future__ import annotations import importlib.util import json import os from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import httpx from fastapi import FastAPI, Form, Request from fastapi.responses import HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sklearn.datasets import ( make_blobs, make_classification, make_gaussian_quantiles, make_s_curve, make_swiss_roll, ) # --------------------------------------------------------------------------- # Paths / constants # --------------------------------------------------------------------------- BASE_DIR = Path(__file__).resolve().parent PROJECT_ROOT = BASE_DIR.parent.parent # /home/mm/work/dr-sandbox FIGS_DIR = PROJECT_ROOT / "figs" FIGS_DIR.mkdir(parents=True, exist_ok=True) PREFECT_API = os.environ.get("PREFECT_API_URL", "http://localhost:4200/api") DEPLOYMENT_NAME = "embedding-flow/embedding-flow" # --------------------------------------------------------------------------- # Dataset catalogue # --------------------------------------------------------------------------- # Metadata for the /data.json endpoint consumed by the dataset picker, and # for server-side lookup when the picker posts its selection back. kwargs # must carry n_features=3 for generators that aren't already 3-D, since # they'll be forwarded verbatim to the Prefect flow's generator_kwargs. DATASET_PREVIEW_N = 5000 DATASET_PREVIEW_SEED = 0 DATASET_META: Dict[str, Dict[str, Any]] = { "s_curve": { "name": "S-Curve", "path": "sklearn.datasets.make_s_curve", "kwargs": {}, "description": ( "A 2-D manifold warped into R³. 
Continuous label encodes position " "along the curve — a good test of whether a reducer unrolls the " "sheet without tearing." ), "kind": "continuous", }, "swiss_roll": { "name": "Swiss Roll", "path": "sklearn.datasets.make_swiss_roll", "kwargs": {}, "description": ( "A rolled-up plane. The canonical hard case for linear methods: " "PCA collapses the spiral, non-linear methods should recover the " "unroll." ), "kind": "continuous", }, "swiss_roll_hole": { "name": "Swiss Roll (hole)", "path": "sklearn.datasets.make_swiss_roll", "kwargs": {"hole": True}, "description": ( "Swiss roll with a rectangular hole punched through. Same manifold, " "non-trivial topology — a faithful unroll should preserve the hole " "rather than smearing it closed." ), "kind": "continuous", }, "blobs": { "name": "Gaussian Blobs", "path": "sklearn.datasets.make_blobs", "kwargs": {"n_features": 3, "centers": 5, "cluster_std": 1.0}, "description": ( "Five isotropic Gaussian clusters in R³. Discrete class labels. " "Tests whether a reducer preserves cluster separation when " "projected to 2-D." ), "kind": "categorical", }, "gaussian_quantiles": { "name": "Gaussian Quantiles", "path": "sklearn.datasets.make_gaussian_quantiles", "kwargs": {"n_features": 3, "n_classes": 4}, "description": ( "Concentric Gaussian shells in R³; class = which shell. Classes " "are linearly inseparable by construction — PCA collapses them, " "kernel and manifold methods have a chance." ), "kind": "categorical", }, "classification": { "name": "Hypercube Clusters", "path": "sklearn.datasets.make_classification", "kwargs": { "n_features": 3, "n_informative": 3, "n_redundant": 0, "n_repeated": 0, "n_classes": 4, "n_clusters_per_class": 2, "class_sep": 1.5, }, "description": ( "Four classes, two sub-clusters each, placed at hypercube vertices " "with informative noise. A denser discrete test than blobs — " "within-class bimodality stresses cluster-preserving reducers." 
), "kind": "categorical", }, } @lru_cache(maxsize=1) def _dataset_previews() -> Dict[str, Dict[str, Any]]: """Attach freshly-generated points+labels to the catalogue for the picker.""" N, SEED = DATASET_PREVIEW_N, DATASET_PREVIEW_SEED s, sl = make_s_curve(n_samples=N, noise=0.03, random_state=SEED) sr, srl = make_swiss_roll(n_samples=N, noise=0.15, random_state=SEED) srh, srhl = make_swiss_roll(n_samples=N, noise=0.15, hole=True, random_state=SEED) b, bl = make_blobs( n_samples=N, n_features=3, centers=5, cluster_std=1.0, random_state=SEED ) gq, gql = make_gaussian_quantiles( n_samples=N, n_features=3, n_classes=4, random_state=SEED ) cls, clsl = make_classification( n_samples=N, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=4, n_clusters_per_class=2, class_sep=1.5, random_state=SEED, ) samples = { "s_curve": (s, sl), "swiss_roll": (sr, srl), "swiss_roll_hole": (srh, srhl), "blobs": (b, bl), "gaussian_quantiles": (gq, gql), "classification": (cls, clsl), } out: Dict[str, Dict[str, Any]] = {} for key, meta in DATASET_META.items(): pts, labels = samples[key] out[key] = { **meta, "points": pts.tolist(), "labels": labels.tolist(), } return out # --------------------------------------------------------------------------- # Reducer catalogue # --------------------------------------------------------------------------- # Each field tuple: (name, kind, default, choices_or_none, help_or_none) # kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null" REDUCERS: Dict[str, Dict[str, Any]] = { "sklearn.decomposition.PCA": { "pkg": "sklearn", "label": "PCA", "blurb": "Principal component analysis. 
# ---------------------------------------------------------------------------
# Reducer catalogue
# ---------------------------------------------------------------------------
# Each field tuple: (name, kind, default, choices_or_none, help_or_none)
# kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null"
# "key" fields are shown by default in the UI; "advanced" fields are the
# fold-out extras.  The dimensionality field (n_components / n_dims) is
# documented as "Locked." and re-forced to 2 in build_embed_args().

REDUCERS: Dict[str, Dict[str, Any]] = {
    "sklearn.decomposition.PCA": {
        # "pkg" is the top-level importable package used by available_reducers()
        # to decide whether this entry is offered at all.
        "pkg": "sklearn",
        "label": "PCA",
        "blurb": "Principal component analysis. Linear, fast, deterministic.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
        ],
        "advanced": [
            ("svd_solver", "str", "auto", ["auto", "full", "arpack", "randomized"], None),
            ("random_state", "int", 42, None, None),
            ("whiten", "bool", False, None, None),
        ],
    },
    "sklearn.decomposition.FactorAnalysis": {
        "pkg": "sklearn",
        "label": "FactorAnalysis",
        "blurb": "Gaussian latent-factor model with per-feature noise.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("tol", "float", 0.01, None, None),
            ("max_iter", "int", 1000, None, None),
            ("rotation", "str", "", ["", "varimax", "quartimax"], "Empty = None."),
        ],
    },
    "sklearn.manifold.TSNE": {
        "pkg": "sklearn",
        "label": "t-SNE",
        "blurb": "Stochastic neighbour embedding. Local structure preserved.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("perplexity", "float", 30.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("learning_rate", "str_or_float", "auto", None, "'auto' or a float."),
            ("n_iter", "int", 1000, None, None),
            ("metric", "str", "euclidean", None, None),
            ("early_exaggeration", "float", 12.0, None, None),
            ("init", "str", "pca", ["pca", "random"], None),
        ],
    },
    "umap.UMAP": {
        "pkg": "umap",
        "label": "UMAP",
        "blurb": "Uniform manifold approximation. Preserves local + some global structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 15, None, None),
            ("min_dist", "float", 0.1, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("metric", "str", "euclidean", None, None),
            ("n_epochs", "int_or_null", "", None, "Empty = None (auto)."),
            ("spread", "float", 1.0, None, None),
            ("init", "str", "spectral", ["spectral", "random"], None),
        ],
    },
    "pacmap.PaCMAP": {
        "pkg": "pacmap",
        "label": "PaCMAP",
        "blurb": "Pairwise-controlled manifold approximation. Balanced local/global.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
        ],
    },
    "trimap.TRIMAP": {
        "pkg": "trimap",
        "label": "TriMap",
        "blurb": "Triplet-based dimensionality reduction. Emphasises global structure.",
        "key": [
            # TriMap names its dimensionality parameter n_dims, not n_components.
            ("n_dims", "int", 2, None, "Locked."),
            ("n_inliers", "int", 10, None, None),
            ("n_outliers", "int", 5, None, None),
            ("n_random", "int", 5, None, None),
        ],
        "advanced": [
            ("lr", "float", 0.1, None, None),
            ("n_iters", "int", 400, None, None),
            ("weight_adj", "float", 500.0, None, None),
        ],
    },
}
def available_reducers() -> List[Tuple[str, Dict[str, Any]]]:
    """Return (key, spec) pairs for reducers whose backing package is importable.

    Uses importlib.util.find_spec so nothing is actually imported; catalogue
    order is preserved.
    """
    return [
        (key, spec)
        for key, spec in REDUCERS.items()
        if importlib.util.find_spec(spec["pkg"]) is not None
    ]


# ---------------------------------------------------------------------------
# Parameter coercion
# ---------------------------------------------------------------------------


def _coerce(kind: str, raw: Optional[str], default: Any) -> Any:
    """Coerce a raw HTML-form value to the type named by *kind*.

    kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null";
    any other kind returns the stripped value unchanged.  A missing field
    (raw is None) or an empty string generally falls back to *default*;
    for "str" an empty input maps to None when the default is None/"" so
    the caller can drop the field entirely.

    Fix vs. previous version: removed the unreachable `s is None` checks
    (raw is None already returned above, so s is never None here) and the
    dead `else ""` arm of the "str" branch (only reachable with a falsy
    default that is not None/"" — impossible for string defaults).
    """
    if raw is None:
        return default
    s = raw.strip() if isinstance(raw, str) else raw
    if kind == "int":
        return default if s == "" else int(s)
    if kind == "float":
        return default if s == "" else float(s)
    if kind == "bool":
        # Checkbox: "on" / absent; explicit falsy spellings count as False.
        return bool(s) and s not in ("0", "false", "False", "")
    if kind == "str":
        if s == "":
            return None if default in (None, "") else default
        return s
    if kind == "str_or_float":
        if s == "":
            return default
        try:
            return float(s)
        except (ValueError, TypeError):
            return s
    if kind == "int_or_null":
        return None if s == "" else int(s)
    return s
def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
    """Translate posted "embed__*" form fields into reducer kwargs.

    Walks every key/advanced field declared for REDUCERS[reducer_key],
    coerces each value via _coerce (checkboxes handled by presence), drops
    None / empty-by-default values, and pins the output dimensionality to 2.
    """
    spec = REDUCERS[reducer_key]
    args: Dict[str, Any] = {}
    for name, kind, default, _choices, _help in list(spec["key"]) + list(spec["advanced"]):
        form_key = f"embed__{name}"
        if kind == "bool":
            # Checkbox semantics: the key is present ("on") only when ticked.
            coerced: Any = bool("on" if form_key in form else "")
        else:
            coerced = _coerce(kind, form.get(form_key), default)
        # Null-stripping: drop empty rotations etc.
        if coerced is None:
            continue
        if isinstance(coerced, str) and coerced == "" and default in (None, ""):
            continue
        args[name] = coerced
    # The flow asserts 2-D output; enforce it regardless of what was posted.
    for dim_key in ("n_components", "n_dims"):
        if dim_key in args:
            args[dim_key] = 2
    return args


# ---------------------------------------------------------------------------
# Output-path synthesis (mirrors flows/embedding_flow.py lines ~162–168)
# ---------------------------------------------------------------------------


def synthesize_output_paths(
    generator_path: str,
    embedder: str,
    num_points: int,
    num_timesteps: int,
    jitter_scale: float,
    seed: int,
) -> Tuple[str, str]:
    """Predict the (reference, embedded) HTML filenames the flow will write.

    Keeps the UI in sync with the flow's naming scheme so figure links can
    be rendered without a round-trip to the server.
    """
    gen_name = generator_path.rsplit(".", 1)[-1]
    emb_name = embedder.rsplit(".", 1)[-1]
    suffix = f"N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    return (
        f"{gen_name}_Reference_{suffix}",
        f"{gen_name}_{emb_name}_{suffix}",
    )
# ---------------------------------------------------------------------------
# Prefect client
# ---------------------------------------------------------------------------


class Prefect:
    """Minimal async wrapper around the three Prefect REST calls the UI needs.

    All methods are best-effort: transport failures and missing deployments
    surface as None / [] / an error dict, never as an exception, so route
    handlers can always render something.
    """

    def __init__(self, base: str = PREFECT_API) -> None:
        self.base = base.rstrip("/")
        self._deployment_id: Optional[str] = None  # lazily-filled cache

    async def deployment_id(self, client: httpx.AsyncClient) -> Optional[str]:
        """Resolve (and cache) the deployment id; None if unreachable/absent."""
        if self._deployment_id:
            return self._deployment_id
        try:
            r = await client.get(f"{self.base}/deployments/name/{DEPLOYMENT_NAME}")
            if r.status_code == 200:
                self._deployment_id = r.json()["id"]
                return self._deployment_id
        except httpx.HTTPError:
            return None
        return None

    async def create_run(
        self, client: httpx.AsyncClient, parameters: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Dispatch a flow run with *parameters*.

        Returns the created-run JSON on success, None when the deployment
        cannot be resolved, or {"error": ..., "status": ...} on failure.
        """
        dep = await self.deployment_id(client)
        if not dep:
            return None
        try:
            r = await client.post(
                f"{self.base}/deployments/{dep}/create_flow_run",
                json={"parameters": parameters},
            )
        except httpx.HTTPError as exc:
            # Fix: previously the only Prefect call that let transport errors
            # propagate (deployment_id/recent_runs swallow them), which 500'd
            # the submit route.  Report it like an HTTP failure instead;
            # status 0 marks "no HTTP response at all".
            return {"error": str(exc), "status": 0}
        if r.status_code >= 400:
            return {"error": r.text, "status": r.status_code}
        return r.json()

    async def recent_runs(
        self, client: httpx.AsyncClient, limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Most recent flow runs for the deployment, newest first; [] on failure."""
        dep = await self.deployment_id(client)
        if not dep:
            return []
        try:
            r = await client.post(
                f"{self.base}/flow_runs/filter",
                json={
                    "sort": "START_TIME_DESC",
                    "limit": limit,
                    "flow_runs": {"deployment_id": {"any_": [dep]}},
                },
            )
            if r.status_code == 200:
                return r.json()
        except httpx.HTTPError:
            return []
        return []


PREFECT = Prefect()

# ---------------------------------------------------------------------------
# In-memory mapping: flow_run_id -> synthesized output file names
# (best-effort; lost on restart, which is fine)
# ---------------------------------------------------------------------------
RUN_OUTPUTS: Dict[str, Dict[str, str]] = {}

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

app = FastAPI(title="web1 — embedding notebook", docs_url=None, redoc_url=None)
app.mount("/figs", StaticFiles(directory=str(FIGS_DIR)), name="figs")
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))


def _fmt_runtime(seconds: Optional[float]) -> Optional[str]:
    """Render a duration as a compact human string; None if unknown or <= 0."""
    if seconds is None or seconds <= 0:
        return None
    if seconds < 60:
        return f"{seconds:.1f}s"
    m, s = divmod(int(seconds), 60)
    if m < 60:
        return f"{m}m{s:02d}s"
    h, m = divmod(m, 60)
    return f"{h}h{m:02d}m"
start = run.get("start_time") or run.get("expected_start_time") or run.get("created") params = run.get("parameters") or {} # estimated_run_time recomputes server-side (ticks up while RUNNING); # total_run_time is the authoritative value once the run finishes. rtime_s = run.get("estimated_run_time") or run.get("total_run_time") # Try to look up synthesised outputs either from memory or from params ref_file = None emb_file = None outs = RUN_OUTPUTS.get(rid) if outs: ref_file = outs["ref"] emb_file = outs["embed"] elif params: try: ref_file, emb_file = synthesize_output_paths( params.get("generator_path", "sklearn.datasets.make_s_curve"), params.get("embedder", "sklearn.decomposition.FactorAnalysis"), int(params.get("num_points", 5000)), # Fallback to the old num_snapshots key for runs dispatched # before the T-rename, so historical figs still resolve after # `rename 's/_S/_T/' figs/*.html`. int(params.get("num_timesteps", params.get("num_snapshots", 48))), float(params.get("jitter_scale", 0.01)), int(params.get("seed", 42)), ) except Exception: ref_file, emb_file = None, None ref_exists = bool(ref_file) and (FIGS_DIR / ref_file).exists() emb_exists = bool(emb_file) and (FIGS_DIR / emb_file).exists() return { "id": rid, "short_id": rid[:8] if rid else "", "name": run.get("name", ""), "state_type": state_type, "state_name": state_name, "start": start, "runtime": _fmt_runtime(float(rtime_s) if rtime_s is not None else None), "params": params, "ref_file": ref_file, "emb_file": emb_file, "ref_exists": ref_exists, "emb_exists": emb_exists, "embedder_short": (params.get("embedder") or "").split(".")[-1], "generator_short": (params.get("generator_path") or "").split(".")[-1], } def _reducer_choices() -> List[Dict[str, str]]: return [ {"key": k, "label": spec["label"], "blurb": spec["blurb"]} for k, spec in available_reducers() ] # --------------------------------------------------------------------------- # Routes # 
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------


@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
    # Landing page: reducer picker (first importable reducer pre-selected)
    # plus the most recent flow runs and deployment status.
    reducers = _reducer_choices()
    default_reducer = reducers[0]["key"] if reducers else None
    default_spec = REDUCERS.get(default_reducer) if default_reducer else None
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=10)
        dep_id = await PREFECT.deployment_id(client)
    views = [_run_view(r) for r in runs]
    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "reducers": reducers,
            "default_reducer": default_reducer,
            "default_spec": default_spec,
            "runs": views,
            "deployment_id": dep_id,
            "prefect_api": PREFECT_API,
        },
    )


@app.get("/data.json")
async def data_json() -> JSONResponse:
    # Dataset catalogue plus cached preview point clouds for the picker.
    return JSONResponse(_dataset_previews())


@app.get("/reducer-form", response_class=HTMLResponse)
async def reducer_form(request: Request, name: str) -> HTMLResponse:
    # Partial view: the parameter form for the reducer selected in the picker.
    spec = REDUCERS.get(name)
    if not spec:
        # NOTE(review): any markup originally wrapped around this message
        # appears to have been lost in transit — confirm the exact response
        # body against version control.
        return HTMLResponse("unknown reducer", status_code=404)
    return templates.TemplateResponse(
        request,
        "_reducer_form.html",
        {"reducer_key": name, "spec": spec},
    )


@app.get("/runs", response_class=HTMLResponse)
async def runs_partial(request: Request) -> HTMLResponse:
    # Partial view: refreshed run list for polling clients.
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=10)
    views = [_run_view(r) for r in runs]
    return templates.TemplateResponse(
        request, "_runs.html", {"runs": views}
    )


@app.post("/submit", response_class=HTMLResponse)
async def submit(request: Request) -> HTMLResponse:
    # NOTE(review): this handler is truncated in the visible chunk — it
    # continues past this view (the run.get('error') fragment below belongs
    # to later error rendering); only the visible text is reproduced here.
    form = await request.form()
    data: Dict[str, str] = {k: str(v) for k, v in form.items()}
    reducer = data.get("reducer") or ""
    if reducer not in REDUCERS:
        return HTMLResponse(
            f"{run.get('error')[:500]}