"""Parse flow-emitted Plotly HTML files into a compact per-frame JSON dict.
The embedding-sandbox flow writes one standalone HTML per run at
``figs/<stem>.html``. Each HTML contains a ``Plotly.newPlot(...)`` call (which
holds the initial frame) and a ``Plotly.addFrames(...)`` call (which holds the
full animation — including a redundant copy of frame 0). Point data is encoded
with plotly's base64+dtype ("bdata") scheme, so we decode it with ``struct``.
The module is stdlib-only.
"""
from __future__ import annotations
import base64
import json
import re
import struct
from pathlib import Path
_STEM_RE = re.compile(
r"^(?Pmake_.+?)_(?P[A-Za-z]+)_N(?P\d+)_T(?P\d+)"
r"_J(?P[\d.]+)_s(?P\d+)(?:_(?P[0-9a-f]{8}))?$"
)
# plotly's typed-array dtype -> (struct format char, item size bytes)
_DTYPE_MAP = {
"i1": ("b", 1),
"u1": ("B", 1),
"i2": ("h", 2),
"u2": ("H", 2),
"i4": ("i", 4),
"u4": ("I", 4),
"i8": ("q", 8),
"u8": ("Q", 8),
"f4": ("f", 4),
"f8": ("d", 8),
}
def _decode_bdata(obj):
# Accepts either a plotly typed-array dict {"dtype","bdata",...} or a plain list.
if isinstance(obj, list):
return list(obj)
if not isinstance(obj, dict) or "bdata" not in obj:
raise ValueError(f"expected bdata object or list, got {type(obj).__name__}")
dt = obj.get("dtype")
if dt not in _DTYPE_MAP:
raise ValueError(f"unsupported dtype {dt!r}")
fmt, size = _DTYPE_MAP[dt]
raw = base64.b64decode(obj["bdata"])
count = len(raw) // size
return list(struct.unpack("<" + fmt * count, raw[: count * size]))
def _extract_call_args(txt: str, call_name: str):
# Locate the LAST occurrence of `call_name(` in the file (plotly's JS
# bundle mentions Plotly.newPlot in docstrings/code earlier, but the real
# user-data call is always emitted at the very bottom).
start = txt.rfind(call_name + "(")
if start < 0:
raise ValueError(f"{call_name}( not found")
open_paren = txt.index("(", start)
i = open_paren + 1
depth = 1
in_str = False
escape = False
args: list[str] = []
cur = i
n = len(txt)
while i < n:
c = txt[i]
if in_str:
if escape:
escape = False
elif c == "\\":
escape = True
elif c == '"':
in_str = False
else:
if c == '"':
in_str = True
elif c in "([{":
depth += 1
elif c in ")]}":
depth -= 1
if depth == 0:
args.append(txt[cur:i].strip())
return args
elif c == "," and depth == 1:
args.append(txt[cur:i].strip())
cur = i + 1
i += 1
raise ValueError(f"unterminated {call_name} arglist")
def _parse_stem(stem: str) -> dict:
    """Decode the run metadata encoded in a figure filename stem.

    Raises ValueError when the stem does not follow the
    generator/embedder/N/T/J/seed naming scheme captured by _STEM_RE.
    """
    match = _STEM_RE.match(stem)
    if match is None:
        raise ValueError(f"stem does not match expected pattern: {stem!r}")
    gen, emb, n, t, j, s = (
        match.group(k) for k in ("gen", "emb", "n", "t", "j", "s")
    )
    return {
        "stem": stem,
        "generator": gen,
        "embedder": emb,
        "num_points": int(n),
        "num_timesteps": int(t),
        "jitter_scale": float(j),
        "seed": int(s),
    }
def _trace_xy_by_id(trace: dict) -> tuple[list[int], list[float], list[float]]:
    """Return (ids, xs, ys) for a single plotly trace.

    Point identity is taken from, in order of preference: the trace's
    explicit `ids`; column 0 of `customdata` (this flow stores the point id
    there); or, as a last resort, the positional index.  All three lists
    must end up the same length, else ValueError is raised.
    """
    if "ids" in trace:
        ids = _decode_bdata(trace["ids"])
        # plotly may coerce ids to strings; normalise digit-strings to ints.
        if ids and isinstance(ids[0], str) and ids[0].isdigit():
            ids = [int(val) for val in ids]
    elif "customdata" in trace:
        custom = trace["customdata"]
        if isinstance(custom, dict):
            ids = _decode_bdata(custom)
        else:
            # rows of shape (N, k): keep only the first column
            ids = [row[0] if isinstance(row, (list, tuple)) else row for row in custom]
    else:
        ids = list(range(len(_decode_bdata(trace["x"]))))
    xs = _decode_bdata(trace["x"])
    ys = _decode_bdata(trace["y"])
    if len(ids) != len(xs) or len(xs) != len(ys):
        raise ValueError(
            f"trace length mismatch: ids={len(ids)} x={len(xs)} y={len(ys)}"
        )
    return ids, xs, ys
def _frame_xy_ordered(traces: list[dict], point_ids: list[int]):
    """Merge one frame's traces into (xs, ys) lists aligned with *point_ids*.

    A frame may carry several traces (one per label group); each contributes
    its points keyed by id.  Ids absent from this frame — points added or
    removed between frames by the jitter_add_remove path — map to None so
    downstream consumers can skip them.
    """
    coords: dict[int, tuple[float, float]] = {}
    for trace in traces:
        ids, xs, ys = _trace_xy_by_id(trace)
        coords.update(zip(ids, zip(xs, ys)))
    xs_out: list = []
    ys_out: list = []
    for pid in point_ids:
        point = coords.get(pid)
        xs_out.append(point[0] if point is not None else None)
        ys_out.append(point[1] if point is not None else None)
    return xs_out, ys_out
def _labels_for_ids(traces: list[dict], point_ids: list[int]) -> list[str]:
    """Per-point class labels derived from trace names/legendgroups.

    With a single trace, or when every trace carries the same name (the
    uniform case this flow emits), there is no label information and every
    point gets "".  Otherwise each point is labelled by the name of the
    trace containing it (later traces win on duplicates); ids never seen in
    any trace get "".
    """
    names = [t.get("name") or t.get("legendgroup") or "" for t in traces]
    if len(traces) <= 1 or len(set(names)) == 1:
        return [""] * len(point_ids)
    label_of: dict[int, str] = {}
    for trace, label in zip(traces, names):
        ids, _, _ = _trace_xy_by_id(trace)
        label_of.update((pid, label) for pid in ids)
    return [label_of.get(pid, "") for pid in point_ids]
def parse_plotly_run(html_path) -> dict:
    """Parse a flow-emitted plotly HTML into a frames dict suitable for
    the three.js comparison page.

    Returns a dict with keys:
      * "meta"      — filename-derived metadata plus the figure title,
      * "point_ids" — stable point-id order (union across all frames),
      * "labels"    — per-point class labels aligned with point_ids,
      * "times"     — one name per frame (frame index when unnamed),
      * "frames"    — per-frame {"x", "y"} lists aligned with point_ids,
                      with None for points absent from that frame,
      * "bounds"    — global x/y min/max across all frames.

    Raises ValueError when the HTML lacks the expected Plotly.newPlot /
    Plotly.addFrames calls or the embedded data is malformed.
    """
    path = Path(html_path)
    meta = _parse_stem(path.stem)
    txt = path.read_text(encoding="utf-8")

    # Plotly.newPlot(div, data, layout, ...): data holds the initial frame's
    # traces; layout may carry the figure title.
    new_args = _extract_call_args(txt, "Plotly.newPlot")
    if len(new_args) < 3:
        raise ValueError(f"Plotly.newPlot expected >=3 args, got {len(new_args)}")
    initial_traces = json.loads(new_args[1])
    layout = json.loads(new_args[2])

    # Plotly.addFrames(div, frames): the full animation, frame 0 included.
    add_args = _extract_call_args(txt, "Plotly.addFrames")
    if len(add_args) < 2:
        raise ValueError(f"Plotly.addFrames expected >=2 args, got {len(add_args)}")
    raw_frames = json.loads(add_args[1])
    if not isinstance(raw_frames, list) or not raw_frames:
        raise ValueError("Plotly.addFrames second arg is not a non-empty list")

    # Establish the stable id set as the UNION of ids across every frame —
    # the flow's jitter_add_remove path intentionally adds/removes points
    # between frames, so no single frame is authoritative. Order is first
    # appearance across frames.
    seen_ids: list[int] = []
    seen_set: set = set()
    for fr in raw_frames:
        for t in fr.get("data") or []:
            ids, _, _ = _trace_xy_by_id(t)
            for pid in ids:
                if pid not in seen_set:
                    seen_set.add(pid)
                    seen_ids.append(pid)
    if not seen_ids:
        raise ValueError("no point ids recovered from any frame")

    first_frame_traces = raw_frames[0].get("data") or initial_traces
    labels = _labels_for_ids(first_frame_traces, seen_ids)

    frames_out = []
    times = []
    for i, fr in enumerate(raw_frames):
        traces = fr.get("data")
        if not traces:
            raise ValueError(f"frame {i} has no 'data'")
        xs, ys = _frame_xy_ordered(traces, seen_ids)
        frames_out.append({"x": xs, "y": ys})
        # Fall back to the frame index when the name is missing OR null, so
        # "times" never contains None entries.
        name = fr.get("name")
        times.append(str(i) if name is None else name)

    # Global bounds across all frames (skipping None placeholders for
    # points that only exist in some frames).
    all_x = [v for fr in frames_out for v in fr["x"] if v is not None]
    all_y = [v for fr in frames_out for v in fr["y"] if v is not None]
    if not all_x or not all_y:
        raise ValueError("no finite x/y values across frames")

    meta["title"] = _layout_title(layout)
    return {
        "meta": meta,
        "point_ids": seen_ids,
        "labels": labels,
        "times": times,
        "frames": frames_out,
        "bounds": {"x": [min(all_x), max(all_x)], "y": [min(all_y), max(all_y)]},
    }


def _layout_title(layout: dict) -> str:
    """Best-effort figure title: plotly allows `layout.title` to be a dict
    ({"text": ...}) or a bare string; anything unexpected yields ""."""
    try:
        title = layout.get("title")
        if isinstance(title, dict):
            return title["text"]
        return str(title or "")
    except Exception:
        return ""
if __name__ == "__main__":
    import sys
    import traceback

    # Smoke-test driver: parse every flow-emitted HTML in the figs directory
    # and print a one-line summary (or the failure) per file.  The directory
    # may be overridden as argv[1]; the default preserves the flow's
    # historical output location.
    figs_dir = (
        Path(sys.argv[1]) if len(sys.argv) > 1 else Path("/home/mm/work/dr-sandbox/figs")
    )
    html_files = sorted(figs_dir.glob("*.html"))
    if not html_files:
        print("no .html files in", figs_dir)
        sys.exit(1)
    failures = []
    for p in html_files:
        try:
            out = parse_plotly_run(p)
        except Exception as e:
            failures.append((p.name, str(e)))
            print(f"FAIL {p.name}: {e}")
            traceback.print_exc()
            continue
        m = out["meta"]
        pids = out["point_ids"]
        frames = out["frames"]
        nT = len(frames)
        # Frame-0 coordinate ranges, ignoring None placeholders for points
        # that do not exist in the first frame.
        f0 = frames[0]
        f0x = [v for v in f0["x"] if v is not None]
        f0y = [v for v in f0["y"] if v is not None]
        xr = (min(f0x), max(f0x)) if f0x else (float("nan"), float("nan"))
        yr = (min(f0y), max(f0y)) if f0y else (float("nan"), float("nan"))
        # Every frame's x/y lists must align with the stable id order.
        consistent = all(
            len(fr["x"]) == len(pids) and len(fr["y"]) == len(pids) for fr in frames
        )
        present_per_frame = [sum(v is not None for v in fr["x"]) for fr in frames]
        present_rng = (min(present_per_frame), max(present_per_frame))
        print(
            f"OK {m['stem']} |ids|={len(pids)} T={nT} "
            f"present/frame=[{present_rng[0]},{present_rng[1]}] "
            f"f0 x=[{xr[0]:+.3f},{xr[1]:+.3f}] y=[{yr[0]:+.3f},{yr[1]:+.3f}] "
            f"consistent={consistent}"
        )
    print()
    if failures:
        print(f"{len(failures)} failure(s):")
        for name, msg in failures:
            print(f"  {name}: {msg}")
    else:
        print(f"all {len(html_files)} files parsed OK")