"""
web1 — "Scientific instrument / research notebook"

A FastAPI UI for kicking off the embedding-flow Prefect deployment and
viewing the resulting HTML animations.

Design: restrained, typography-driven, two-column notebook layout. No CSS
framework; hand-written styles.
"""
|
||
|
||
from __future__ import annotations

import html
import importlib.util
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import httpx
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sklearn.datasets import (
    make_blobs,
    make_classification,
    make_gaussian_quantiles,
    make_s_curve,
    make_swiss_roll,
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Paths / constants
# ---------------------------------------------------------------------------

BASE_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = BASE_DIR.parent.parent  # /home/mm/work/dr-sandbox
# Where the Prefect flow writes its HTML animations; served at /figs below.
FIGS_DIR = PROJECT_ROOT / "figs"
FIGS_DIR.mkdir(parents=True, exist_ok=True)

# Prefect server API base URL; override with PREFECT_API_URL for non-local setups.
PREFECT_API = os.environ.get("PREFECT_API_URL", "http://localhost:4200/api")
# "<flow name>/<deployment name>" as registered with the Prefect server.
DEPLOYMENT_NAME = "embedding-flow/embedding-flow"


# ---------------------------------------------------------------------------
# Dataset catalogue
# ---------------------------------------------------------------------------
# Metadata for the /data.json endpoint consumed by the dataset picker, and
# for server-side lookup when the picker posts its selection back. kwargs
# must carry n_features=3 for generators that aren't already 3-D, since
# they'll be forwarded verbatim to the Prefect flow's generator_kwargs.

# Sample count / RNG seed used only for the picker's preview point clouds.
DATASET_PREVIEW_N = 5000
DATASET_PREVIEW_SEED = 0
# Each entry: display name, dotted generator path, kwargs forwarded verbatim
# to the flow, a human description, and a label "kind".
# NOTE(review): "kind" distinguishes continuous labels from categorical ones —
# presumably drives the picker's colouring; confirm against the template/JS.
DATASET_META: Dict[str, Dict[str, Any]] = {
    "s_curve": {
        "name": "S-Curve",
        "path": "sklearn.datasets.make_s_curve",
        "kwargs": {},
        "description": (
            "A 2-D manifold warped into R³. Continuous label encodes position "
            "along the curve — a good test of whether a reducer unrolls the "
            "sheet without tearing."
        ),
        "kind": "continuous",
    },
    "swiss_roll": {
        "name": "Swiss Roll",
        "path": "sklearn.datasets.make_swiss_roll",
        "kwargs": {},
        "description": (
            "A rolled-up plane. The canonical hard case for linear methods: "
            "PCA collapses the spiral, non-linear methods should recover the "
            "unroll."
        ),
        "kind": "continuous",
    },
    "swiss_roll_hole": {
        "name": "Swiss Roll (hole)",
        # Same generator as swiss_roll; only the hole kwarg differs.
        "path": "sklearn.datasets.make_swiss_roll",
        "kwargs": {"hole": True},
        "description": (
            "Swiss roll with a rectangular hole punched through. Same manifold, "
            "non-trivial topology — a faithful unroll should preserve the hole "
            "rather than smearing it closed."
        ),
        "kind": "continuous",
    },
    "blobs": {
        "name": "Gaussian Blobs",
        "path": "sklearn.datasets.make_blobs",
        "kwargs": {"n_features": 3, "centers": 5, "cluster_std": 1.0},
        "description": (
            "Five isotropic Gaussian clusters in R³. Discrete class labels. "
            "Tests whether a reducer preserves cluster separation when "
            "projected to 2-D."
        ),
        "kind": "categorical",
    },
    "gaussian_quantiles": {
        "name": "Gaussian Quantiles",
        "path": "sklearn.datasets.make_gaussian_quantiles",
        "kwargs": {"n_features": 3, "n_classes": 4},
        "description": (
            "Concentric Gaussian shells in R³; class = which shell. Classes "
            "are linearly inseparable by construction — PCA collapses them, "
            "kernel and manifold methods have a chance."
        ),
        "kind": "categorical",
    },
    "classification": {
        "name": "Hypercube Clusters",
        "path": "sklearn.datasets.make_classification",
        "kwargs": {
            "n_features": 3,
            "n_informative": 3,
            "n_redundant": 0,
            "n_repeated": 0,
            "n_classes": 4,
            "n_clusters_per_class": 2,
            "class_sep": 1.5,
        },
        "description": (
            "Four classes, two sub-clusters each, placed at hypercube vertices "
            "with informative noise. A denser discrete test than blobs — "
            "within-class bimodality stresses cluster-preserving reducers."
        ),
        "kind": "categorical",
    },
}
|
||
|
||
|
||
@lru_cache(maxsize=1)
def _dataset_previews() -> Dict[str, Dict[str, Any]]:
    """Attach freshly-generated points+labels to the catalogue for the picker.

    Generated once per process (lru_cache) with the fixed preview size/seed,
    so every page load serves the identical clouds.
    """
    n, seed = DATASET_PREVIEW_N, DATASET_PREVIEW_SEED
    # One (points, labels) sample per catalogue key, same kwargs as the flow
    # would receive, plus a small noise term for a realistic-looking preview.
    generated: Dict[str, Any] = {
        "s_curve": make_s_curve(n_samples=n, noise=0.03, random_state=seed),
        "swiss_roll": make_swiss_roll(n_samples=n, noise=0.15, random_state=seed),
        "swiss_roll_hole": make_swiss_roll(
            n_samples=n, noise=0.15, hole=True, random_state=seed
        ),
        "blobs": make_blobs(
            n_samples=n, n_features=3, centers=5, cluster_std=1.0, random_state=seed
        ),
        "gaussian_quantiles": make_gaussian_quantiles(
            n_samples=n, n_features=3, n_classes=4, random_state=seed
        ),
        "classification": make_classification(
            n_samples=n,
            n_features=3,
            n_informative=3,
            n_redundant=0,
            n_repeated=0,
            n_classes=4,
            n_clusters_per_class=2,
            class_sep=1.5,
            random_state=seed,
        ),
    }
    enriched: Dict[str, Dict[str, Any]] = {}
    for key, meta in DATASET_META.items():
        points, labels = generated[key]
        enriched[key] = {
            **meta,
            "points": points.tolist(),
            "labels": labels.tolist(),
        }
    return enriched
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Reducer catalogue
# ---------------------------------------------------------------------------
# Each field tuple: (name, kind, default, choices_or_none, help_or_none)
# kinds: "int", "float", "str", "bool", "str_or_float", "int_or_null"
#
# "pkg" is probed with importlib.util.find_spec in available_reducers() to
# decide whether the reducer is offered at all.
# NOTE(review): "key" vs "advanced" presumably controls form grouping in the
# template — confirm against _reducer_form.html.

REDUCERS: Dict[str, Dict[str, Any]] = {
    "sklearn.decomposition.PCA": {
        "pkg": "sklearn",
        "label": "PCA",
        "blurb": "Principal component analysis. Linear, fast, deterministic.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
        ],
        "advanced": [
            ("svd_solver", "str", "auto", ["auto", "full", "arpack", "randomized"], None),
            ("random_state", "int", 42, None, None),
            ("whiten", "bool", False, None, None),
        ],
    },
    "sklearn.decomposition.FactorAnalysis": {
        "pkg": "sklearn",
        "label": "FactorAnalysis",
        "blurb": "Gaussian latent-factor model with per-feature noise.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("tol", "float", 0.01, None, None),
            ("max_iter", "int", 1000, None, None),
            # Empty string round-trips to None via _coerce's "str" handling.
            ("rotation", "str", "", ["", "varimax", "quartimax"], "Empty = None."),
        ],
    },
    "sklearn.manifold.TSNE": {
        "pkg": "sklearn",
        "label": "t-SNE",
        "blurb": "Stochastic neighbour embedding. Local structure preserved.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("perplexity", "float", 30.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("learning_rate", "str_or_float", "auto", None, "'auto' or a float."),
            ("n_iter", "int", 1000, None, None),
            ("metric", "str", "euclidean", None, None),
            ("early_exaggeration", "float", 12.0, None, None),
            ("init", "str", "pca", ["pca", "random"], None),
        ],
    },
    "umap.UMAP": {
        "pkg": "umap",
        "label": "UMAP",
        "blurb": "Uniform manifold approximation. Preserves local + some global structure.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 15, None, None),
            ("min_dist", "float", 0.1, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("metric", "str", "euclidean", None, None),
            ("n_epochs", "int_or_null", "", None, "Empty = None (auto)."),
            ("spread", "float", 1.0, None, None),
            ("init", "str", "spectral", ["spectral", "random"], None),
        ],
    },
    "pacmap.PaCMAP": {
        "pkg": "pacmap",
        "label": "PaCMAP",
        "blurb": "Pairwise-controlled manifold approximation. Balanced local/global.",
        "key": [
            ("n_components", "int", 2, None, "Locked."),
            ("n_neighbors", "int", 10, None, None),
            ("MN_ratio", "float", 0.5, None, None),
            ("FP_ratio", "float", 2.0, None, None),
            ("random_state", "int", 42, None, None),
        ],
        "advanced": [
            ("lr", "float", 1.0, None, None),
            ("num_iters", "int", 450, None, None),
            ("apply_pca", "bool", True, None, None),
        ],
    },
    "trimap.TRIMAP": {
        "pkg": "trimap",
        "label": "TriMap",
        "blurb": "Triplet-based dimensionality reduction. Emphasises global structure.",
        "key": [
            # TriMap names its dimensionality parameter n_dims, not n_components;
            # build_embed_args forces either to 2.
            ("n_dims", "int", 2, None, "Locked."),
            ("n_inliers", "int", 10, None, None),
            ("n_outliers", "int", 5, None, None),
            ("n_random", "int", 5, None, None),
        ],
        "advanced": [
            ("lr", "float", 0.1, None, None),
            ("n_iters", "int", 400, None, None),
            ("weight_adj", "float", 500.0, None, None),
        ],
    },
}
|
||
|
||
|
||
def available_reducers() -> List[Tuple[str, Dict[str, Any]]]:
    """Return (key, spec) pairs for reducers whose backing package is importable."""
    return [
        (key, spec)
        for key, spec in REDUCERS.items()
        if importlib.util.find_spec(spec["pkg"]) is not None
    ]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Parameter coercion
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _coerce(kind: str, raw: Optional[str], default: Any) -> Any:
|
||
if raw is None:
|
||
return default
|
||
s = raw.strip() if isinstance(raw, str) else raw
|
||
if kind == "int":
|
||
if s == "" or s is None:
|
||
return default
|
||
return int(s)
|
||
if kind == "float":
|
||
if s == "" or s is None:
|
||
return default
|
||
return float(s)
|
||
if kind == "bool":
|
||
# Checkbox: "on" / absent
|
||
return bool(s) and s not in ("0", "false", "False", "")
|
||
if kind == "str":
|
||
if s == "":
|
||
return None if default in (None, "") else default if default else ""
|
||
return s
|
||
if kind == "str_or_float":
|
||
if s == "":
|
||
return default
|
||
try:
|
||
return float(s)
|
||
except (ValueError, TypeError):
|
||
return s
|
||
if kind == "int_or_null":
|
||
if s == "":
|
||
return None
|
||
return int(s)
|
||
return s
|
||
|
||
|
||
def build_embed_args(reducer_key: str, form: Dict[str, str]) -> Dict[str, Any]:
    """Assemble the embedder kwargs from ``embed__*``-prefixed form fields.

    Walks the reducer's "key" + "advanced" field specs, coerces each raw form
    value via _coerce, drops values that mean "leave unset" (None, or an empty
    string whose default is empty-ish), and finally pins the output
    dimensionality to 2 to satisfy the flow's assertion.

    Raises KeyError if *reducer_key* is not in REDUCERS (callers validate
    membership first) and ValueError on unparseable numeric input.
    """
    spec = REDUCERS[reducer_key]
    out: Dict[str, Any] = {}
    all_fields = list(spec["key"]) + list(spec["advanced"])
    for (name, kind, default, _choices, _help) in all_fields:
        field = f"embed__{name}"
        raw = form.get(field)
        if kind == "bool":
            # Checkbox semantics: absent key means unchecked (False). When the
            # key IS present, run the value through _coerce so explicit falsey
            # strings ("0", "false", "False", "") read as False — previously
            # any present value, even "false", coerced to True.
            value = _coerce("bool", raw, False) if field in form else False
        else:
            value = _coerce(kind, raw, default)
        # Null-stripping: drop empty rotations etc.
        if value is None:
            continue
        if isinstance(value, str) and value == "" and default in (None, ""):
            continue
        out[name] = value

    # Always force n_components / n_dims to 2 (flow assertion)
    if "n_components" in out:
        out["n_components"] = 2
    if "n_dims" in out:
        out["n_dims"] = 2
    return out
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Output-path synthesis (mirrors flows/embedding_flow.py lines ~162–168)
# ---------------------------------------------------------------------------


def synthesize_output_paths(
    generator_path: str,
    embedder: str,
    num_points: int,
    num_timesteps: int,
    jitter_scale: float,
    seed: int,
) -> Tuple[str, str]:
    """Predict the (reference, embedding) HTML filenames the flow will write.

    Uses the trailing component of the dotted generator/embedder paths plus
    the run parameters, matching the flow's own naming scheme exactly.
    """
    generator = generator_path.rsplit(".", 1)[-1]
    embedder_name = embedder.rsplit(".", 1)[-1]
    suffix = f"N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    reference = f"{generator}_Reference_{suffix}"
    embedding = f"{generator}_{embedder_name}_{suffix}"
    return reference, embedding
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Prefect client
# ---------------------------------------------------------------------------


class Prefect:
    """Minimal async REST client for the Prefect API.

    All methods degrade gracefully (None / []) when the server is
    unreachable so the UI can render a "Prefect down" state instead of 500s.
    """

    def __init__(self, base: str = PREFECT_API) -> None:
        self.base = base.rstrip("/")
        # Cached after the first successful lookup; never invalidated, which
        # is fine for a UI targeting one long-lived deployment.
        self._deployment_id: Optional[str] = None

    async def deployment_id(self, client: httpx.AsyncClient) -> Optional[str]:
        """Resolve (and cache) the id for DEPLOYMENT_NAME, or None on failure."""
        if self._deployment_id:
            return self._deployment_id
        try:
            r = await client.get(f"{self.base}/deployments/name/{DEPLOYMENT_NAME}")
            if r.status_code == 200:
                self._deployment_id = r.json()["id"]
                return self._deployment_id
        except httpx.HTTPError:
            return None
        return None

    async def create_run(
        self, client: httpx.AsyncClient, parameters: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Kick off a flow run with *parameters*.

        Returns the run dict on success, an {"error", "status"} dict when
        Prefect rejects the request, or None when the API is unreachable.
        """
        dep = await self.deployment_id(client)
        if not dep:
            return None
        try:
            r = await client.post(
                f"{self.base}/deployments/{dep}/create_flow_run",
                json={"parameters": parameters},
            )
        except httpx.HTTPError:
            # Transport failure (connect error, timeout): report "unreachable"
            # (None) like the other methods, instead of letting the exception
            # bubble up out of /submit as an unhandled 500.
            return None
        if r.status_code >= 400:
            return {"error": r.text, "status": r.status_code}
        return r.json()

    async def recent_runs(
        self, client: httpx.AsyncClient, limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Most recent flow runs for the deployment, newest first ([] on failure)."""
        dep = await self.deployment_id(client)
        if not dep:
            return []
        try:
            r = await client.post(
                f"{self.base}/flow_runs/filter",
                json={
                    "sort": "START_TIME_DESC",
                    "limit": limit,
                    "flow_runs": {"deployment_id": {"any_": [dep]}},
                },
            )
            if r.status_code == 200:
                return r.json()
        except httpx.HTTPError:
            return []
        return []
|
||
|
||
|
||
# Module-level singleton client shared by all request handlers.
PREFECT = Prefect()


# ---------------------------------------------------------------------------
# In-memory mapping: flow_run_id -> synthesized output file names
# (best-effort; lost on restart, which is fine)
# ---------------------------------------------------------------------------

RUN_OUTPUTS: Dict[str, Dict[str, str]] = {}


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

app = FastAPI(title="web1 — embedding notebook", docs_url=None, redoc_url=None)

# /figs serves the flow's generated HTML animations; /static the UI's assets.
app.mount("/figs", StaticFiles(directory=str(FIGS_DIR)), name="figs")
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")

templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
|
||
|
||
|
||
def _fmt_runtime(seconds: Optional[float]) -> Optional[str]:
|
||
if seconds is None or seconds <= 0:
|
||
return None
|
||
if seconds < 60:
|
||
return f"{seconds:.1f}s"
|
||
m, s = divmod(int(seconds), 60)
|
||
if m < 60:
|
||
return f"{m}m{s:02d}s"
|
||
h, m = divmod(m, 60)
|
||
return f"{h}h{m:02d}m"
|
||
|
||
|
||
def _run_view(run: Dict[str, Any]) -> Dict[str, Any]:
    """Normalise a flow-run dict for the template.

    Flattens the Prefect flow-run payload into the flat keys _runs.html
    consumes, formats the runtime, and works out which output HTML files the
    run maps to (and whether they exist on disk yet).
    """
    rid = run.get("id", "")
    # Prefect runs that haven't been scheduled yet may lack state fields.
    state_type = (run.get("state_type") or "PENDING").upper()
    state_name = run.get("state_name") or state_type.title()
    # Best available timestamp: actual start, else scheduled, else creation.
    start = run.get("start_time") or run.get("expected_start_time") or run.get("created")
    params = run.get("parameters") or {}
    # estimated_run_time recomputes server-side (ticks up while RUNNING);
    # total_run_time is the authoritative value once the run finishes.
    rtime_s = run.get("estimated_run_time") or run.get("total_run_time")
    # Try to look up synthesised outputs either from memory or from params
    ref_file = None
    emb_file = None
    outs = RUN_OUTPUTS.get(rid)
    if outs:
        # Filenames recorded when /submit dispatched this run in-process.
        ref_file = outs["ref"]
        emb_file = outs["embed"]
    elif params:
        # Run predates this process (or RUN_OUTPUTS was lost on restart):
        # re-derive the filenames from the run's own parameters. Defaults here
        # mirror the flow's — presumably; confirm against embedding_flow.py.
        try:
            ref_file, emb_file = synthesize_output_paths(
                params.get("generator_path", "sklearn.datasets.make_s_curve"),
                params.get("embedder", "sklearn.decomposition.FactorAnalysis"),
                int(params.get("num_points", 5000)),
                # Fallback to the old num_snapshots key for runs dispatched
                # before the T-rename, so historical figs still resolve after
                # `rename 's/_S/_T/' figs/*.html`.
                int(params.get("num_timesteps", params.get("num_snapshots", 48))),
                float(params.get("jitter_scale", 0.01)),
                int(params.get("seed", 42)),
            )
        except Exception:
            # Malformed params (wrong types, etc.) — just omit the links.
            ref_file, emb_file = None, None

    # Only link to outputs that the flow has actually written to disk.
    ref_exists = bool(ref_file) and (FIGS_DIR / ref_file).exists()
    emb_exists = bool(emb_file) and (FIGS_DIR / emb_file).exists()

    return {
        "id": rid,
        "short_id": rid[:8] if rid else "",
        "name": run.get("name", ""),
        "state_type": state_type,
        "state_name": state_name,
        "start": start,
        "runtime": _fmt_runtime(float(rtime_s) if rtime_s is not None else None),
        "params": params,
        "ref_file": ref_file,
        "emb_file": emb_file,
        "ref_exists": ref_exists,
        "emb_exists": emb_exists,
        # Trailing path components, e.g. "UMAP" / "make_s_curve", for display.
        "embedder_short": (params.get("embedder") or "").split(".")[-1],
        "generator_short": (params.get("generator_path") or "").split(".")[-1],
    }
|
||
|
||
|
||
def _reducer_choices() -> List[Dict[str, str]]:
    """Flatten available reducers into template-friendly {key,label,blurb} rows."""
    choices: List[Dict[str, str]] = []
    for key, spec in available_reducers():
        choices.append(
            {"key": key, "label": spec["label"], "blurb": spec["blurb"]}
        )
    return choices
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------


@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> HTMLResponse:
    """Main notebook page: submit form plus the recent-runs column."""
    reducers = _reducer_choices()
    default_reducer = reducers[0]["key"] if reducers else None
    default_spec = REDUCERS.get(default_reducer) if default_reducer else None
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=10)
        dep_id = await PREFECT.deployment_id(client)
    context = {
        "reducers": reducers,
        "default_reducer": default_reducer,
        "default_spec": default_spec,
        "runs": [_run_view(r) for r in runs],
        "deployment_id": dep_id,
        "prefect_api": PREFECT_API,
    }
    return templates.TemplateResponse(request, "index.html", context)
|
||
|
||
|
||
@app.get("/data.json")
async def data_json() -> JSONResponse:
    """Dataset catalogue plus preview point clouds for the client-side picker."""
    previews = _dataset_previews()
    return JSONResponse(previews)
|
||
|
||
|
||
@app.get("/reducer-form", response_class=HTMLResponse)
async def reducer_form(request: Request, name: str) -> HTMLResponse:
    """Render the parameter sub-form for one reducer (htmx fragment)."""
    spec = REDUCERS.get(name)
    if not spec:
        return HTMLResponse("<p class='err'>unknown reducer</p>", status_code=404)
    context = {"reducer_key": name, "spec": spec}
    return templates.TemplateResponse(request, "_reducer_form.html", context)
|
||
|
||
|
||
@app.get("/runs", response_class=HTMLResponse)
async def runs_partial(request: Request) -> HTMLResponse:
    """Recent-runs column as an htmx-swappable fragment (polled by the page)."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        recent = await PREFECT.recent_runs(client, limit=10)
    return templates.TemplateResponse(
        request, "_runs.html", {"runs": [_run_view(r) for r in recent]}
    )
|
||
|
||
|
||
@app.post("/submit", response_class=HTMLResponse)
async def submit(request: Request) -> HTMLResponse:
    """Validate the form, dispatch a Prefect flow run, return the runs partial.

    Error paths return small `<div class='flash err'>` fragments for htmx to
    swap in. Any text that echoes client input or upstream responses is
    escaped with html.escape before being inlined into those fragments.
    """
    form = await request.form()
    data: Dict[str, str] = {k: str(v) for k, v in form.items()}

    reducer = data.get("reducer") or ""
    if reducer not in REDUCERS:
        # `reducer` is client-controlled — escape before echoing into HTML.
        return HTMLResponse(
            f"<div class='flash err'>unknown reducer: {html.escape(reducer)}</div>",
            status_code=400,
        )

    # Dataset came from the picker via dataset_id; fall back to explicit
    # generator_path / generator_kwargs if a client posts those directly.
    dataset_id = data.get("dataset_id") or ""
    if dataset_id and dataset_id in DATASET_META:
        meta = DATASET_META[dataset_id]
        generator_path = meta["path"]
        generator_kwargs = dict(meta["kwargs"])
    else:
        generator_path = data.get("generator_path") or "sklearn.datasets.make_s_curve"
        raw_kwargs = data.get("generator_kwargs") or ""
        try:
            generator_kwargs = json.loads(raw_kwargs) if raw_kwargs else {}
        except json.JSONDecodeError as e:
            # The JSONDecodeError message quotes the client's input — escape it.
            return HTMLResponse(
                f"<div class='flash err'>bad generator_kwargs JSON: {html.escape(str(e))}</div>",
                status_code=400,
            )

    try:
        num_points = int(data.get("num_points", "5000") or 5000)
        num_timesteps = int(data.get("num_timesteps", "48") or 48)
        jitter_scale = float(data.get("jitter_scale", "0.01") or 0.01)
        seed = int(data.get("seed", "42") or 42)
    except ValueError as e:
        # ValueError messages quote the offending client value — escape them.
        return HTMLResponse(
            f"<div class='flash err'>bad numeric input: {html.escape(str(e))}</div>",
            status_code=400,
        )

    embed_args = build_embed_args(reducer, data)

    parameters: Dict[str, Any] = {
        "num_points": num_points,
        "num_timesteps": num_timesteps,
        "jitter_scale": jitter_scale,
        "seed": seed,
        "generator_path": generator_path,
        "embedder": reducer,
        "embed_args": embed_args,
    }
    if generator_kwargs:
        parameters["generator_kwargs"] = generator_kwargs

    async with httpx.AsyncClient(timeout=10.0) as client:
        run = await PREFECT.create_run(client, parameters)

    if not run:
        return HTMLResponse(
            "<div class='flash err'>could not reach Prefect API at "
            f"{PREFECT_API}</div>",
            status_code=502,
        )
    if "error" in run:
        # Prefect's error body can reflect our (client-derived) parameters —
        # escape it and cap the length before inlining into the fragment.
        return HTMLResponse(
            f"<div class='flash err'>prefect error ({run.get('status')}): "
            f"<code>{html.escape(run.get('error')[:500])}</code></div>",
            status_code=502,
        )

    # Remember the filenames this run will produce so the runs list can link
    # to them (RUN_OUTPUTS is best-effort and lost on restart).
    ref_file, emb_file = synthesize_output_paths(
        generator_path, reducer, num_points, num_timesteps, jitter_scale, seed
    )
    RUN_OUTPUTS[run["id"]] = {"ref": ref_file, "embed": emb_file}

    # Return freshly refreshed runs partial so htmx can swap the right column
    async with httpx.AsyncClient(timeout=5.0) as client:
        runs = await PREFECT.recent_runs(client, limit=10)
    views = [_run_view(r) for r in runs]
    return templates.TemplateResponse(
        request,
        "_runs.html",
        {"runs": views, "just_submitted": run["id"]},
    )
|
||
|
||
|
||
def _scan_metrics() -> List[Dict[str, Any]]:
    """Read every `*.metrics.json` in FIGS_DIR and return them as a list.

    Records are ordered newest-first by file mtime; unreadable or malformed
    files are skipped silently.
    """
    newest_first = sorted(
        FIGS_DIR.glob("*.metrics.json"),
        key=lambda path: path.stat().st_mtime,
        reverse=True,
    )
    records: List[Dict[str, Any]] = []
    for path in newest_first:
        try:
            record = json.loads(path.read_text())
        except (OSError, json.JSONDecodeError):
            continue
        record["filename"] = path.name
        record["embedding_file"] = path.name.replace(".metrics.json", ".html")
        records.append(record)
    return records
|
||
|
||
|
||
@app.get("/metrics", response_class=HTMLResponse)
async def metrics_page(request: Request) -> HTMLResponse:
    """Metrics comparison page shell; data arrives client-side via /metrics.json."""
    context = {"prefect_api": PREFECT_API}
    return templates.TemplateResponse(request, "metrics.html", context)
|
||
|
||
|
||
@app.get("/metrics.json")
async def metrics_json() -> JSONResponse:
    """All scanned metrics records (newest first, per _scan_metrics) as JSON."""
    records = _scan_metrics()
    return JSONResponse(records)
|
||
|
||
|
||
@app.get("/health")
async def health() -> JSONResponse:
    """Liveness probe; also reports whether the Prefect deployment resolves."""
    async with httpx.AsyncClient(timeout=3.0) as client:
        dep = await PREFECT.deployment_id(client)
    payload = {"ok": True, "deployment_id": dep, "prefect_api": PREFECT_API}
    return JSONResponse(payload)
|