# File: app/web/main.py -- imports and the frames endpoint touched by this patch.
from __future__ import annotations

import importlib.util
import json
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import httpx
from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.web.plotly_parse import parse_plotly_run

# NOTE: the remaining imports of main.py (starting with the sklearn.datasets
# import) fall outside the visible hunk and are not reproduced here.

# Shape of an accepted run stem: make_<generator>_<embedder>_N<int>_T<int>_
# J<digits/dots>_s<int>. The character classes admit no path separators, so a
# matching stem cannot point outside FIGS_DIR.
_STEM_RE = re.compile(
    r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
)


@lru_cache(maxsize=32)
def _cached_frames(stem: str) -> str:
    """Parse ``<stem>.html`` and return the frames dict as a JSON string.

    The result is memoised per stem (up to 32 entries), so a run whose HTML
    is rewritten on disk keeps serving the first parse until it is evicted.
    """
    html = FIGS_DIR / f"{stem}.html"
    return json.dumps(parse_plotly_run(html), separators=(",", ":"))


@app.get("/api/runs/{stem}/frames.json")
async def run_frames(stem: str) -> Response:
    """Serve the parsed per-frame JSON for one run.

    Responds 400 when ``stem`` does not have the expected shape, 404 when
    ``FIGS_DIR/<stem>.html`` is not a file, and 500 when parsing fails.
    """
    # fullmatch, not match: with match() the trailing "$" would also accept a
    # stem that ends in a newline.
    if not _STEM_RE.fullmatch(stem):
        raise HTTPException(400, f"malformed stem: {stem!r}")
    html = FIGS_DIR / f"{stem}.html"
    if not html.is_file():
        raise HTTPException(404, f"no such run: {stem}")
    # NOTE(review): the parse below is synchronous file I/O + CPU work inside
    # an async handler -- confirm that is acceptable for the expected file
    # sizes, or move it to a worker thread.
    try:
        payload = _cached_frames(stem)
    except Exception as e:
        # Boundary handler: any parser failure becomes a 500, with the
        # original exception kept as the cause for server-side tracebacks.
        raise HTTPException(500, f"parse failed: {e}") from e
    return Response(content=payload, media_type="application/json")
# File: app/web/plotly_parse.py
"""Parse flow-emitted Plotly HTML files into a compact per-frame JSON dict.

The embedding-sandbox flow writes one standalone HTML per run at
``figs/<stem>.html``. Each HTML contains a ``Plotly.newPlot(...)`` call (which
holds the initial frame) and a ``Plotly.addFrames(...)`` call (which holds the
full animation — including a redundant copy of frame 0). Point data is encoded
with plotly's base64+dtype ("bdata") scheme, so we decode it with ``struct``.

The module is stdlib-only.
"""

from __future__ import annotations

import base64
import json
import math
import re
import struct
import sys
from pathlib import Path


# Run stem layout: <gen>_<emb>_N<n>_T<t>_J<j>_s<s>. The group names are the
# ones `_parse_stem` reads back.
_STEM_RE = re.compile(
    r"^(?P<gen>make_.+?)_(?P<emb>[A-Za-z]+)_N(?P<n>\d+)_T(?P<t>\d+)"
    r"_J(?P<j>[\d.]+)_s(?P<s>\d+)$"
)

# plotly's typed-array dtype -> (struct format char, item size bytes)
_DTYPE_MAP = {
    "i1": ("b", 1),
    "u1": ("B", 1),
    "i2": ("h", 2),
    "u2": ("H", 2),
    "i4": ("i", 4),
    "u4": ("I", 4),
    "i8": ("q", 8),
    "u8": ("Q", 8),
    "f4": ("f", 4),
    "f8": ("d", 8),
}


def _decode_bdata(obj):
    """Decode a plotly array into a flat Python list.

    Accepts either a plotly typed-array dict ``{"dtype", "bdata", ...}`` or a
    plain list (returned as a shallow copy). Values are read little-endian;
    trailing bytes that do not fill a whole item are ignored.

    Raises ValueError for any other input or for a dtype not in _DTYPE_MAP.
    """
    if isinstance(obj, list):
        return list(obj)
    if not isinstance(obj, dict) or "bdata" not in obj:
        raise ValueError(f"expected bdata object or list, got {type(obj).__name__}")
    dt = obj.get("dtype")
    if dt not in _DTYPE_MAP:
        raise ValueError(f"unsupported dtype {dt!r}")
    fmt, size = _DTYPE_MAP[dt]
    raw = base64.b64decode(obj["bdata"])
    count = len(raw) // size
    # A repeat count keeps the format string short instead of building a
    # `count`-character string.
    return list(struct.unpack(f"<{count}{fmt}", raw[: count * size]))


def _extract_call_args(txt: str, call_name: str):
    """Return the top-level argument texts of the last ``call_name(...)`` call.

    Locates the LAST occurrence of ``call_name(`` in the file (plotly's JS
    bundle mentions Plotly.newPlot in docstrings/code earlier, but the real
    user-data call is always emitted at the very bottom) and splits its
    argument list on commas at nesting depth 1. Brackets, commas and quotes
    inside string literals are ignored; both double- and single-quoted
    literals are recognised, with backslash escapes.

    Raises ValueError if the call is absent or its arglist never closes.
    """
    start = txt.rfind(call_name + "(")
    if start < 0:
        raise ValueError(f"{call_name}( not found")
    pos = start + len(call_name) + 1  # first character after the "("
    arg_start = pos
    depth = 1
    quote = ""  # active string delimiter, "" while outside a string
    escaped = False
    args: list[str] = []
    for i in range(pos, len(txt)):
        c = txt[i]
        if quote:
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == quote:
                quote = ""
        elif c == '"' or c == "'":
            quote = c
        elif c in "([{":
            depth += 1
        elif c in ")]}":
            depth -= 1
            if depth == 0:
                args.append(txt[arg_start:i].strip())
                return args
        elif c == "," and depth == 1:
            args.append(txt[arg_start:i].strip())
            arg_start = i + 1
    raise ValueError(f"unterminated {call_name} arglist")


def _parse_stem(stem: str) -> dict:
    """Split a run stem into its metadata fields.

    Raises ValueError if the stem does not match `_STEM_RE`.
    """
    m = _STEM_RE.match(stem)
    if not m:
        raise ValueError(f"stem does not match expected pattern: {stem!r}")
    return {
        "stem": stem,
        "generator": m.group("gen"),
        "embedder": m.group("emb"),
        "num_points": int(m.group("n")),
        "num_timesteps": int(m.group("t")),
        "jitter_scale": float(m.group("j")),
        "seed": int(m.group("s")),
    }


def _customdata_ids(cd) -> list:
    """Return column 0 of a trace's ``customdata`` as the point ids."""
    if not isinstance(cd, dict):
        # Plain nested list of shape (N, k): pick the first column. Scalars
        # (already 1-D customdata) pass through.
        return [row[0] if isinstance(row, (list, tuple)) else row for row in cd]
    flat = _decode_bdata(cd)
    shape = cd.get("shape")
    if isinstance(shape, str):
        dims = [int(part) for part in shape.split(",") if part.strip()]
    elif isinstance(shape, (list, tuple)):
        dims = [int(part) for part in shape]
    else:
        dims = []
    # A 2-D typed array is laid out row by row, so column 0 is every k-th item.
    if len(dims) == 2 and dims[1] > 1 and dims[0] * dims[1] == len(flat):
        return flat[:: dims[1]]
    return flat


def _trace_xy_by_id(trace: dict) -> tuple[list[int], list[float], list[float]]:
    """Return ``(ids, xs, ys)`` for one trace, all of equal length.

    Ids come from explicit ``ids`` when present; otherwise from ``customdata``
    (which in this flow also carries the point id in column 0); otherwise the
    positional index is used.

    Raises ValueError if the three sequences differ in length, and KeyError
    if the trace has no ``x`` or ``y``.
    """
    xs = _decode_bdata(trace["x"])
    ys = _decode_bdata(trace["y"])
    if "ids" in trace:
        ids = _decode_bdata(trace["ids"])
        # ids may arrive as strings too (plotly coerces). Convert to ints only
        # when EVERY id is a decimal string; a mixed list is left untouched
        # rather than failing on the first non-numeric entry.
        if ids and all(isinstance(v, str) and v.isdecimal() for v in ids):
            ids = [int(v) for v in ids]
    elif "customdata" in trace:
        ids = _customdata_ids(trace["customdata"])
    else:
        ids = list(range(len(xs)))
    if not (len(ids) == len(xs) == len(ys)):
        raise ValueError(
            f"trace length mismatch: ids={len(ids)} x={len(xs)} y={len(ys)}"
        )
    return ids, xs, ys


def _is_non_finite(value) -> bool:
    """True for float NaN / +-inf; every other value counts as usable."""
    return isinstance(value, float) and not math.isfinite(value)


def _frame_xy_ordered(traces: list[dict], point_ids: list[int]):
    """Merge a frame's traces into one ``(xs, ys)`` pair ordered by point_ids.

    A frame may hold several traces (one per label-group). Ids missing from
    this frame — points added or removed between frames by the
    jitter_add_remove path — become ``None`` so downstream code can skip them.
    Points whose x or y is NaN/inf are emitted as ``None`` as well, because
    such floats would poison the bounds and make the JSON payload invalid.
    When an id occurs in more than one trace, the last occurrence wins.
    """
    missing = (None, None)
    by_id: dict[int, tuple[float, float]] = {}
    for t in traces:
        ids, xs, ys = _trace_xy_by_id(t)
        for pid, x, y in zip(ids, xs, ys):
            if _is_non_finite(x) or _is_non_finite(y):
                by_id[pid] = missing
            else:
                by_id[pid] = (x, y)
    pairs = [by_id.get(pid, missing) for pid in point_ids]
    return [p[0] for p in pairs], [p[1] for p in pairs]


def _labels_for_ids(traces: list[dict], point_ids: list[int]) -> list[str]:
    """Return one class label per entry of point_ids.

    If traces carry distinct names/legendgroups, that name is the per-point
    label. When there is at most one trace, or all traces share the same name
    (the single-trace case emitted by this flow), labels are uniformly "".
    Ids not present in any of the given traces also get "".
    """
    names = [t.get("name") or t.get("legendgroup") or "" for t in traces]
    if len(traces) <= 1 or all(n == names[0] for n in names):
        return ["" for _ in point_ids]
    by_id: dict[int, str] = {}
    for t, name in zip(traces, names):
        ids, _, _ = _trace_xy_by_id(t)
        for pid in ids:
            by_id[pid] = name
    return [by_id.get(pid, "") for pid in point_ids]


def _collect_point_ids(raw_frames: list) -> list:
    """Return the union of point ids over all frames, in first-seen order."""
    seen_ids: list[int] = []
    seen_set: set = set()
    for fr in raw_frames:
        for t in fr.get("data") or []:
            ids, _, _ = _trace_xy_by_id(t)
            for pid in ids:
                if pid not in seen_set:
                    seen_set.add(pid)
                    seen_ids.append(pid)
    if not seen_ids:
        raise ValueError("no point ids recovered from any frame")
    return seen_ids


def _layout_title(layout) -> str:
    """Best-effort figure title; "" when the layout carries none."""
    if not isinstance(layout, dict):
        return ""
    title = layout.get("title")
    if isinstance(title, dict):
        return title.get("text", "")
    return str(title or "")


def parse_plotly_run(html_path) -> dict:
    """Parse a flow-emitted plotly HTML into a frames dict suitable for
    the three.js comparison page.

    The returned dict has the keys ``meta`` (stem fields plus ``title``),
    ``point_ids``, ``labels``, ``times``, ``frames`` (one ``{"x", "y"}`` per
    frame, aligned with ``point_ids``; ``None`` marks a point absent from
    that frame) and ``bounds`` (global min/max over all frames).

    Raises ValueError on unrecognised shape; errors from reading the file
    propagate unchanged.
    """
    path = Path(html_path)
    stem = path.stem
    meta = _parse_stem(stem)

    txt = path.read_text(encoding="utf-8")

    new_args = _extract_call_args(txt, "Plotly.newPlot")
    if len(new_args) < 3:
        raise ValueError(f"Plotly.newPlot expected >=3 args, got {len(new_args)}")
    initial_traces = json.loads(new_args[1])
    layout = json.loads(new_args[2])

    add_args = _extract_call_args(txt, "Plotly.addFrames")
    if len(add_args) < 2:
        raise ValueError(f"Plotly.addFrames expected 2 args, got {len(add_args)}")
    raw_frames = json.loads(add_args[1])
    if not isinstance(raw_frames, list) or not raw_frames:
        raise ValueError("Plotly.addFrames second arg is not a non-empty list")

    # Establish the stable id set as the UNION of ids across every frame —
    # the flow's jitter_add_remove path intentionally adds/removes points
    # between frames, so no single frame is authoritative. Order is first
    # appearance: frame by frame, trace by trace, in the trace's own order.
    seen_ids = _collect_point_ids(raw_frames)

    first_frame_traces = raw_frames[0].get("data") or initial_traces
    labels = _labels_for_ids(first_frame_traces, seen_ids)

    frames_out = []
    times = []
    for i, fr in enumerate(raw_frames):
        traces = fr.get("data")
        if not traces:
            raise ValueError(f"frame {i} has no 'data'")
        xs, ys = _frame_xy_ordered(traces, seen_ids)
        frames_out.append({"x": xs, "y": ys})
        times.append(fr.get("name", str(i)))

    # Global bounds across all frames (skipping None placeholders for
    # points that only exist in some frames).
    all_x = [v for fr in frames_out for v in fr["x"] if v is not None]
    all_y = [v for fr in frames_out for v in fr["y"] if v is not None]
    if not all_x or not all_y:
        raise ValueError("no finite x/y values across frames")
    xmin, xmax = min(all_x), max(all_x)
    ymin, ymax = min(all_y), max(all_y)

    meta["title"] = _layout_title(layout)

    return {
        "meta": meta,
        "point_ids": seen_ids,
        "labels": labels,
        "times": times,
        "frames": frames_out,
        "bounds": {"x": [xmin, xmax], "y": [ymin, ymax]},
    }


def _main(argv=None) -> int:
    """Smoke-test driver: parse every ``*.html`` in a figs directory.

    The directory is the first command-line argument when given, otherwise
    the original development default. Returns the process exit status: 1 if
    the directory holds no HTML files or any file failed to parse, else 0.
    """
    import traceback

    args = sys.argv[1:] if argv is None else list(argv)
    figs_dir = Path(args[0]) if args else Path("/home/mm/work/dr-sandbox/figs")
    html_files = sorted(figs_dir.glob("*.html"))
    if not html_files:
        print("no .html files in", figs_dir)
        return 1

    failures = []
    for p in html_files:
        try:
            out = parse_plotly_run(p)
        except Exception as e:
            # Report and keep going so one bad file does not hide the rest.
            failures.append((p.name, str(e)))
            print(f"FAIL {p.name}: {e}")
            traceback.print_exc()
            continue
        m = out["meta"]
        pids = out["point_ids"]
        frames = out["frames"]
        nT = len(frames)
        f0 = frames[0]
        f0x = [v for v in f0["x"] if v is not None]
        f0y = [v for v in f0["y"] if v is not None]
        xr = (min(f0x), max(f0x)) if f0x else (float("nan"), float("nan"))
        yr = (min(f0y), max(f0y)) if f0y else (float("nan"), float("nan"))
        consistent = all(
            len(fr["x"]) == len(pids) and len(fr["y"]) == len(pids) for fr in frames
        )
        present_per_frame = [sum(v is not None for v in fr["x"]) for fr in frames]
        present_rng = (min(present_per_frame), max(present_per_frame))
        print(
            f"OK {m['stem']} |ids|={len(pids)} T={nT} "
            f"present/frame=[{present_rng[0]},{present_rng[1]}] "
            f"f0 x=[{xr[0]:+.3f},{xr[1]:+.3f}] y=[{yr[0]:+.3f},{yr[1]:+.3f}] "
            f"consistent={consistent}"
        )

    print()
    if failures:
        print(f"{len(failures)} failure(s):")
        for name, msg in failures:
            print(f" {name}: {msg}")
        return 1
    print(f"all {len(html_files)} files parsed OK")
    return 0


if __name__ == "__main__":
    sys.exit(_main())