compare: parse plotly HTML into frames JSON, expose at /api/runs/{stem}/frames.json

This commit is contained in:
Michael Pilosov 2026-04-22 14:16:30 -06:00
parent acb596743a
commit b016dbdaee
2 changed files with 321 additions and 2 deletions

View File

@ -13,13 +13,16 @@ from __future__ import annotations
import importlib.util import importlib.util
import json import json
import os import os
import re
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from app.web.plotly_parse import parse_plotly_run
import httpx import httpx
from fastapi import FastAPI, Form, Request from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sklearn.datasets import ( from sklearn.datasets import (
@ -804,6 +807,32 @@ async def metrics_json() -> JSONResponse:
return JSONResponse(_scan_metrics()) return JSONResponse(_scan_metrics())
# Strict whitelist for run stems arriving in the URL path: generator name,
# embedder name, then numeric N/T/J/seed fields. Anything else is rejected
# with a 400 before we ever touch the filesystem (also blocks path tricks
# like "../" since only [A-Za-z_0-9.] content can match).
_STEM_RE = re.compile(
    r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
)
@lru_cache(maxsize=32)
def _cached_frames(stem: str) -> str:
    """Parse <stem>.html and return the frames dict as a JSON string."""
    # Compact separators keep the payload small for the browser.
    # NOTE(review): cached per-stem for the life of the process — if
    # figs/<stem>.html is regenerated, stale frames are served until
    # restart. Confirm runs are write-once.
    html = FIGS_DIR / f"{stem}.html"
    return json.dumps(parse_plotly_run(html), separators=(",", ":"))
@app.get("/api/runs/{stem}/frames.json")
async def run_frames(stem: str) -> Response:
    """Serve the parsed per-frame JSON for one run's plotly HTML.

    Responses:
        400 — ``stem`` does not match the expected run-naming pattern.
        404 — ``figs/<stem>.html`` does not exist.
        500 — the HTML exists but could not be parsed.
    """
    if not _STEM_RE.match(stem):
        raise HTTPException(400, f"malformed stem: {stem!r}")
    html = FIGS_DIR / f"{stem}.html"
    if not html.is_file():
        raise HTTPException(404, f"no such run: {stem}")
    try:
        payload = _cached_frames(stem)
    except Exception as e:
        # Chain the cause so the original traceback survives into server logs
        # instead of being masked by the HTTPException.
        raise HTTPException(500, f"parse failed: {e}") from e
    # payload is already a serialized JSON string; use a raw Response to
    # avoid JSONResponse re-encoding it.
    return Response(content=payload, media_type="application/json")
@app.get("/health") @app.get("/health")
async def health() -> JSONResponse: async def health() -> JSONResponse:
async with httpx.AsyncClient(timeout=3.0) as client: async with httpx.AsyncClient(timeout=3.0) as client:

290
app/web/plotly_parse.py Normal file
View File

@ -0,0 +1,290 @@
"""Parse flow-emitted Plotly HTML files into a compact per-frame JSON dict.
The embedding-sandbox flow writes one standalone HTML per run at
``figs/<stem>.html``. Each HTML contains a ``Plotly.newPlot(...)`` call (which
holds the initial frame) and a ``Plotly.addFrames(...)`` call (which holds the
full animation including a redundant copy of frame 0). Point data is encoded
with plotly's base64+dtype ("bdata") scheme, so we decode it with ``struct``.
The module is stdlib-only.
"""
from __future__ import annotations
import base64
import json
import re
import struct
from pathlib import Path
# Run-stem naming scheme emitted by the flow:
#   <generator>_<embedder>_N<points>_T<timesteps>_J<jitter>_s<seed>
# e.g. "make_blobs_tsne_N100_T20_J0.5_s7". Named groups are consumed by
# _parse_stem below.
_STEM_RE = re.compile(
    r"^(?P<gen>make_.+?)_(?P<emb>[A-Za-z]+)_N(?P<n>\d+)_T(?P<t>\d+)"
    r"_J(?P<j>[\d.]+)_s(?P<s>\d+)$"
)
# plotly's typed-array dtype -> (struct format char, item size bytes)
_DTYPE_MAP = {
"i1": ("b", 1),
"u1": ("B", 1),
"i2": ("h", 2),
"u2": ("H", 2),
"i4": ("i", 4),
"u4": ("I", 4),
"i8": ("q", 8),
"u8": ("Q", 8),
"f4": ("f", 4),
"f8": ("d", 8),
}
def _decode_bdata(obj):
# Accepts either a plotly typed-array dict {"dtype","bdata",...} or a plain list.
if isinstance(obj, list):
return list(obj)
if not isinstance(obj, dict) or "bdata" not in obj:
raise ValueError(f"expected bdata object or list, got {type(obj).__name__}")
dt = obj.get("dtype")
if dt not in _DTYPE_MAP:
raise ValueError(f"unsupported dtype {dt!r}")
fmt, size = _DTYPE_MAP[dt]
raw = base64.b64decode(obj["bdata"])
count = len(raw) // size
return list(struct.unpack("<" + fmt * count, raw[: count * size]))
def _extract_call_args(txt: str, call_name: str):
    """Return the top-level argument strings of the LAST ``call_name(...)``
    call in *txt*, split on depth-1 commas.

    This is a small character-level scanner: it tracks bracket depth across
    (), [], {} and double-quoted JSON strings (with backslash escapes), so
    commas and brackets inside string literals never affect the split.
    Raises ValueError if the call or its closing paren is missing.
    """
    # Locate the LAST occurrence of `call_name(` in the file (plotly's JS
    # bundle mentions Plotly.newPlot in docstrings/code earlier, but the real
    # user-data call is always emitted at the very bottom).
    start = txt.rfind(call_name + "(")
    if start < 0:
        raise ValueError(f"{call_name}( not found")
    open_paren = txt.index("(", start)
    i = open_paren + 1
    depth = 1
    in_str = False   # currently inside a double-quoted string literal
    escape = False   # previous char was a backslash inside a string
    args: list[str] = []
    cur = i          # start index of the argument being accumulated
    n = len(txt)
    while i < n:
        c = txt[i]
        if in_str:
            # Only an unescaped '"' terminates the string; everything else
            # (commas, brackets) is literal content.
            if escape:
                escape = False
            elif c == "\\":
                escape = True
            elif c == '"':
                in_str = False
        else:
            if c == '"':
                in_str = True
            elif c in "([{":
                depth += 1
            elif c in ")]}":
                depth -= 1
                if depth == 0:
                    # Closing paren of the call itself: emit the final arg.
                    args.append(txt[cur:i].strip())
                    return args
            elif c == "," and depth == 1:
                # Top-level comma: argument boundary.
                args.append(txt[cur:i].strip())
                cur = i + 1
        i += 1
    raise ValueError(f"unterminated {call_name} arglist")
def _parse_stem(stem: str) -> dict:
    """Split a run stem such as ``make_blobs_tsne_N100_T20_J0.5_s7`` into
    its metadata fields. Raises ValueError when the stem does not match
    the flow's naming scheme."""
    match = _STEM_RE.match(stem)
    if match is None:
        raise ValueError(f"stem does not match expected pattern: {stem!r}")
    gen, emb, n, t, j, s = (
        match.group(name) for name in ("gen", "emb", "n", "t", "j", "s")
    )
    return {
        "stem": stem,
        "generator": gen,
        "embedder": emb,
        "num_points": int(n),
        "num_timesteps": int(t),
        "jitter_scale": float(j),
        "seed": int(s),
    }
def _trace_xy_by_id(trace: dict) -> tuple[list[int], list[float], list[float]]:
    """Return ``(ids, xs, ys)`` for a single trace, decoding bdata payloads.

    Id source preference: explicit ``ids`` -> ``customdata`` column 0 (this
    flow stores the point id there) -> positional index. Raises ValueError
    when the three lists disagree in length.
    """
    # Decode coordinates first so the positional-index fallback can reuse
    # len(xs) instead of decoding trace["x"] a second time.
    xs = _decode_bdata(trace["x"])
    ys = _decode_bdata(trace["y"])
    if "ids" in trace:
        ids = _decode_bdata(trace["ids"])
        # ids may arrive as strings (plotly coerces); normalise digit
        # strings to ints so id sets compare across frames.
        if ids and isinstance(ids[0], str) and ids[0].isdigit():
            ids = [int(v) for v in ids]
    elif "customdata" in trace:
        cd = trace["customdata"]
        if isinstance(cd, dict):
            ids = _decode_bdata(cd)
        else:
            # shape (N, k): pick first column
            ids = [row[0] if isinstance(row, (list, tuple)) else row for row in cd]
    else:
        ids = list(range(len(xs)))
    if not (len(ids) == len(xs) == len(ys)):
        raise ValueError(
            f"trace length mismatch: ids={len(ids)} x={len(xs)} y={len(ys)}"
        )
    return ids, xs, ys
def _frame_xy_ordered(traces: list[dict], point_ids: list[int]):
    """Collapse one frame's traces (possibly one per label-group) into a
    single ``(xs, ys)`` pair ordered by *point_ids*.

    Ids absent from this frame — points added or removed between frames by
    the jitter_add_remove path — come back as ``None`` placeholders so
    downstream code can skip them.
    """
    coords: dict[int, tuple[float, float]] = {}
    for trace in traces:
        ids, xs, ys = _trace_xy_by_id(trace)
        coords.update(zip(ids, zip(xs, ys)))
    absent = (None, None)
    xs_out = [coords.get(pid, absent)[0] for pid in point_ids]
    ys_out = [coords.get(pid, absent)[1] for pid in point_ids]
    return xs_out, ys_out
def _labels_for_ids(traces: list[dict], point_ids: list[int]) -> list[str]:
    """Derive a per-point class label from trace names/legendgroups.

    The single-trace case (or all traces sharing one name) yields uniform
    ``""`` labels; otherwise each point inherits the name of the trace that
    carries it, with ``""`` for points not found in any trace.
    """
    names = [t.get("name") or t.get("legendgroup") or "" for t in traces]
    if len(traces) <= 1 or len(set(names)) == 1:
        return [""] * len(point_ids)
    label_of: dict[int, str] = {}
    for trace, name in zip(traces, names):
        ids, _, _ = _trace_xy_by_id(trace)
        label_of.update((pid, name) for pid in ids)
    return [label_of.get(pid, "") for pid in point_ids]
def parse_plotly_run(html_path) -> dict:
    """Parse a flow-emitted plotly HTML into a frames dict suitable for
    the three.js comparison page. Raises ValueError on unrecognised shape."""
    path = Path(html_path)
    stem = path.stem
    meta = _parse_stem(stem)
    txt = path.read_text(encoding="utf-8")
    # Plotly.newPlot(div, data, layout, ...): arg index 1 holds the initial
    # trace list, index 2 the layout.
    new_args = _extract_call_args(txt, "Plotly.newPlot")
    if len(new_args) < 3:
        raise ValueError(f"Plotly.newPlot expected >=3 args, got {len(new_args)}")
    initial_traces = json.loads(new_args[1])
    layout = json.loads(new_args[2])
    # Plotly.addFrames(div, frames): arg index 1 is the full animation
    # (including a redundant copy of frame 0, per the module docstring).
    add_args = _extract_call_args(txt, "Plotly.addFrames")
    if len(add_args) < 2:
        raise ValueError(f"Plotly.addFrames expected 2 args, got {len(add_args)}")
    raw_frames = json.loads(add_args[1])
    if not isinstance(raw_frames, list) or not raw_frames:
        raise ValueError("Plotly.addFrames second arg is not a non-empty list")
    # Establish the stable id set as the UNION of ids across every frame —
    # the flow's jitter_add_remove path intentionally adds/removes points
    # between frames, so no single frame is authoritative. Order is: first
    # appearance across frames, then by numeric id within each frame's new ids.
    seen_ids: list[int] = []
    seen_set: set = set()
    for fr in raw_frames:
        for t in fr.get("data") or []:
            ids, _, _ = _trace_xy_by_id(t)
            for pid in ids:
                if pid not in seen_set:
                    seen_set.add(pid)
                    seen_ids.append(pid)
    if not seen_ids:
        raise ValueError("no point ids recovered from any frame")
    # Labels come from frame 0's traces, falling back to the newPlot traces
    # when frame 0 carries no "data".
    first_frame_traces = raw_frames[0].get("data") or initial_traces
    labels = _labels_for_ids(first_frame_traces, seen_ids)
    frames_out = []
    times = []
    for i, fr in enumerate(raw_frames):
        traces = fr.get("data")
        if not traces:
            raise ValueError(f"frame {i} has no 'data'")
        xs, ys = _frame_xy_ordered(traces, seen_ids)
        frames_out.append({"x": xs, "y": ys})
        # NOTE(review): dict.get only falls back on a MISSING key — a frame
        # with an explicit "name": null would append None here. Confirm the
        # flow always emits string frame names.
        times.append(fr.get("name", str(i)))
    # Global bounds across all frames (skipping None placeholders for
    # points that only exist in some frames).
    all_x = [v for fr in frames_out for v in fr["x"] if v is not None]
    all_y = [v for fr in frames_out for v in fr["y"] if v is not None]
    if not all_x or not all_y:
        raise ValueError("no finite x/y values across frames")
    xmin, xmax = min(all_x), max(all_x)
    ymin, ymax = min(all_y), max(all_y)
    # layout.title may be either a plain string or a {"text": ...} dict;
    # tolerate both (and any other odd shape) without failing the parse.
    title = ""
    try:
        title = layout["title"]["text"] if isinstance(layout.get("title"), dict) else str(layout.get("title") or "")
    except Exception:
        title = ""
    meta["title"] = title
    return {
        "meta": meta,
        "point_ids": seen_ids,
        "labels": labels,
        "times": times,
        "frames": frames_out,
        "bounds": {"x": [xmin, xmax], "y": [ymin, ymax]},
    }
if __name__ == "__main__":
    # Smoke-test CLI: parse every *.html under the figs directory and print
    # one summary line (or a FAIL line with traceback) per file.
    import sys
    import traceback

    # Accept an optional figs-dir argument instead of only the hard-coded
    # developer path; the default keeps the previous behavior.
    figs_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("/home/mm/work/dr-sandbox/figs")
    html_files = sorted(figs_dir.glob("*.html"))
    if not html_files:
        print("no .html files in", figs_dir)
        sys.exit(1)
    failures = []
    for p in html_files:
        try:
            out = parse_plotly_run(p)
        except Exception as e:
            failures.append((p.name, str(e)))
            print(f"FAIL {p.name}: {e}")
            traceback.print_exc()
            continue
        m = out["meta"]
        pids = out["point_ids"]
        frames = out["frames"]
        nT = len(frames)
        # Frame-0 coordinate ranges, ignoring None placeholders for points
        # absent from that frame.
        f0 = frames[0]
        f0x = [v for v in f0["x"] if v is not None]
        f0y = [v for v in f0["y"] if v is not None]
        xr = (min(f0x), max(f0x)) if f0x else (float("nan"), float("nan"))
        yr = (min(f0y), max(f0y)) if f0y else (float("nan"), float("nan"))
        # Every frame should be padded out to the full union id set.
        consistent = all(len(fr["x"]) == len(pids) and len(fr["y"]) == len(pids) for fr in frames)
        present_per_frame = [sum(v is not None for v in fr["x"]) for fr in frames]
        present_rng = (min(present_per_frame), max(present_per_frame))
        print(
            f"OK {m['stem']} |ids|={len(pids)} T={nT} "
            f"present/frame=[{present_rng[0]},{present_rng[1]}] "
            f"f0 x=[{xr[0]:+.3f},{xr[1]:+.3f}] y=[{yr[0]:+.3f},{yr[1]:+.3f}] "
            f"consistent={consistent}"
        )
        print()
    if failures:
        print(f"{len(failures)} failure(s):")
        for name, msg in failures:
            print(f"  {name}: {msg}")
    else:
        print(f"all {len(html_files)} files parsed OK")