291 lines
10 KiB
Python
291 lines
10 KiB
Python
"""Parse flow-emitted Plotly HTML files into a compact per-frame JSON dict.
|
|
|
|
The embedding-sandbox flow writes one standalone HTML per run at
|
|
``figs/<stem>.html``. Each HTML contains a ``Plotly.newPlot(...)`` call (which
|
|
holds the initial frame) and a ``Plotly.addFrames(...)`` call (which holds the
|
|
full animation — including a redundant copy of frame 0). Point data is encoded
|
|
with plotly's base64+dtype ("bdata") scheme, so we decode it with ``struct``.
|
|
|
|
The module is stdlib-only.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import re
|
|
import struct
|
|
from pathlib import Path
|
|
|
|
|
|
_STEM_RE = re.compile(
|
|
r"^(?P<gen>make_.+?)_(?P<emb>[A-Za-z]+)_N(?P<n>\d+)_T(?P<t>\d+)"
|
|
r"_J(?P<j>[\d.]+)_s(?P<s>\d+)$"
|
|
)
|
|
|
|
# plotly's typed-array dtype -> (struct format char, item size bytes)
|
|
_DTYPE_MAP = {
|
|
"i1": ("b", 1),
|
|
"u1": ("B", 1),
|
|
"i2": ("h", 2),
|
|
"u2": ("H", 2),
|
|
"i4": ("i", 4),
|
|
"u4": ("I", 4),
|
|
"i8": ("q", 8),
|
|
"u8": ("Q", 8),
|
|
"f4": ("f", 4),
|
|
"f8": ("d", 8),
|
|
}
|
|
|
|
|
|
def _decode_bdata(obj):
|
|
# Accepts either a plotly typed-array dict {"dtype","bdata",...} or a plain list.
|
|
if isinstance(obj, list):
|
|
return list(obj)
|
|
if not isinstance(obj, dict) or "bdata" not in obj:
|
|
raise ValueError(f"expected bdata object or list, got {type(obj).__name__}")
|
|
dt = obj.get("dtype")
|
|
if dt not in _DTYPE_MAP:
|
|
raise ValueError(f"unsupported dtype {dt!r}")
|
|
fmt, size = _DTYPE_MAP[dt]
|
|
raw = base64.b64decode(obj["bdata"])
|
|
count = len(raw) // size
|
|
return list(struct.unpack("<" + fmt * count, raw[: count * size]))
|
|
|
|
|
|
def _extract_call_args(txt: str, call_name: str):
|
|
# Locate the LAST occurrence of `call_name(` in the file (plotly's JS
|
|
# bundle mentions Plotly.newPlot in docstrings/code earlier, but the real
|
|
# user-data call is always emitted at the very bottom).
|
|
start = txt.rfind(call_name + "(")
|
|
if start < 0:
|
|
raise ValueError(f"{call_name}( not found")
|
|
open_paren = txt.index("(", start)
|
|
i = open_paren + 1
|
|
depth = 1
|
|
in_str = False
|
|
escape = False
|
|
args: list[str] = []
|
|
cur = i
|
|
n = len(txt)
|
|
while i < n:
|
|
c = txt[i]
|
|
if in_str:
|
|
if escape:
|
|
escape = False
|
|
elif c == "\\":
|
|
escape = True
|
|
elif c == '"':
|
|
in_str = False
|
|
else:
|
|
if c == '"':
|
|
in_str = True
|
|
elif c in "([{":
|
|
depth += 1
|
|
elif c in ")]}":
|
|
depth -= 1
|
|
if depth == 0:
|
|
args.append(txt[cur:i].strip())
|
|
return args
|
|
elif c == "," and depth == 1:
|
|
args.append(txt[cur:i].strip())
|
|
cur = i + 1
|
|
i += 1
|
|
raise ValueError(f"unterminated {call_name} arglist")
|
|
|
|
|
|
def _parse_stem(stem: str) -> dict:
|
|
m = _STEM_RE.match(stem)
|
|
if not m:
|
|
raise ValueError(f"stem does not match expected pattern: {stem!r}")
|
|
return {
|
|
"stem": stem,
|
|
"generator": m.group("gen"),
|
|
"embedder": m.group("emb"),
|
|
"num_points": int(m.group("n")),
|
|
"num_timesteps": int(m.group("t")),
|
|
"jitter_scale": float(m.group("j")),
|
|
"seed": int(m.group("s")),
|
|
}
|
|
|
|
|
|
def _trace_xy_by_id(trace: dict) -> tuple[list[int], list[float], list[float]]:
    """Return ``(ids, xs, ys)`` for a single plotly trace.

    Id source preference: an explicit ``ids`` field first, then column 0 of
    ``customdata`` (which in this flow also carries the point id), and
    finally the positional index. Raises ValueError when the three lists
    disagree in length.
    """
    if "ids" in trace:
        ids = _decode_bdata(trace["ids"])
        # plotly may coerce ids to strings; prefer ints when the strings
        # are purely digits, otherwise leave them as-is.
        if ids and isinstance(ids[0], str) and ids[0].isdigit():
            ids = [int(v) for v in ids]
    elif "customdata" in trace:
        custom = trace["customdata"]
        if isinstance(custom, dict):
            ids = _decode_bdata(custom)
        else:
            # Rows of shape (N, k): the id lives in column 0.
            ids = [row[0] if isinstance(row, (list, tuple)) else row for row in custom]
    else:
        ids = list(range(len(_decode_bdata(trace["x"]))))
    xs = _decode_bdata(trace["x"])
    ys = _decode_bdata(trace["y"])
    if len(ids) != len(xs) or len(xs) != len(ys):
        raise ValueError(
            f"trace length mismatch: ids={len(ids)} x={len(xs)} y={len(ys)}"
        )
    return ids, xs, ys
|
|
|
|
|
|
def _frame_xy_ordered(traces: list[dict], point_ids: list[int]):
    """Merge one frame's traces into ``(xs, ys)`` ordered by *point_ids*.

    A frame may carry several traces (one per label-group). Ids absent from
    this frame — points added or removed between frames by the
    jitter_add_remove path — map to ``None`` so downstream code can skip
    them.
    """
    coords: dict[int, tuple[float, float]] = {}
    for trace in traces:
        ids, xs, ys = _trace_xy_by_id(trace)
        coords.update(zip(ids, zip(xs, ys)))
    xs_out = []
    ys_out = []
    for pid in point_ids:
        xy = coords.get(pid)
        xs_out.append(None if xy is None else xy[0])
        ys_out.append(None if xy is None else xy[1])
    return xs_out, ys_out
|
|
|
|
|
|
def _labels_for_ids(traces: list[dict], point_ids: list[int]) -> list[str]:
    """Derive a per-point class label from trace names/legendgroups.

    When there is at most one trace, or every trace carries the same name
    (the single-trace case emitted by this flow), every point gets "".
    Otherwise each point inherits the name of the trace it appears in;
    points found in no trace get "".
    """
    names = [t.get("name") or t.get("legendgroup") or "" for t in traces]
    if len(traces) <= 1 or names.count(names[0]) == len(names):
        return [""] * len(point_ids)
    label_of: dict[int, str] = {}
    for trace, label in zip(traces, names):
        ids, _, _ = _trace_xy_by_id(trace)
        for pid in ids:
            label_of[pid] = label
    return [label_of.get(pid, "") for pid in point_ids]
|
|
|
|
|
|
def parse_plotly_run(html_path) -> dict:
    """Parse a flow-emitted plotly HTML into a frames dict suitable for
    the three.js comparison page. Raises ValueError on unrecognised shape.

    Parameters
    ----------
    html_path : str | Path
        Path to a ``figs/<stem>.html`` file whose stem encodes run metadata.

    Returns
    -------
    dict with keys:
      meta      -- stem-derived run metadata plus the figure title
      point_ids -- stable id order (union of ids across all frames)
      labels    -- per-point label strings aligned with point_ids
      times     -- one frame name per animation frame
      frames    -- [{"x": [...], "y": [...]}, ...] aligned with point_ids,
                   with None for points absent from a given frame
      bounds    -- global {"x": [min, max], "y": [min, max]} over all frames
    """
    path = Path(html_path)
    stem = path.stem
    meta = _parse_stem(stem)

    txt = path.read_text(encoding="utf-8")

    # Plotly.newPlot(div, traces, layout, ...): traces/layout are JSON.
    new_args = _extract_call_args(txt, "Plotly.newPlot")
    if len(new_args) < 3:
        raise ValueError(f"Plotly.newPlot expected >=3 args, got {len(new_args)}")
    initial_traces = json.loads(new_args[1])
    layout = json.loads(new_args[2])

    # Plotly.addFrames(div, frames): the full animation, including a
    # redundant copy of frame 0.
    add_args = _extract_call_args(txt, "Plotly.addFrames")
    if len(add_args) < 2:
        # Message matches the >=2 guard (previously claimed exactly 2).
        raise ValueError(f"Plotly.addFrames expected >=2 args, got {len(add_args)}")
    raw_frames = json.loads(add_args[1])
    if not isinstance(raw_frames, list) or not raw_frames:
        raise ValueError("Plotly.addFrames second arg is not a non-empty list")

    # Establish the stable id set as the UNION of ids across every frame —
    # the flow's jitter_add_remove path intentionally adds/removes points
    # between frames, so no single frame is authoritative. Order is: first
    # appearance across frames, then by numeric id within each frame's new ids.
    seen_ids: list[int] = []
    seen_set: set = set()
    for fr in raw_frames:
        for t in fr.get("data") or []:
            ids, _, _ = _trace_xy_by_id(t)
            for pid in ids:
                if pid not in seen_set:
                    seen_set.add(pid)
                    seen_ids.append(pid)
    if not seen_ids:
        raise ValueError("no point ids recovered from any frame")

    # Frame 0 is the most complete source of labels; fall back to the
    # newPlot traces if the frame carries no data.
    first_frame_traces = raw_frames[0].get("data") or initial_traces
    labels = _labels_for_ids(first_frame_traces, seen_ids)

    frames_out = []
    times = []
    for i, fr in enumerate(raw_frames):
        traces = fr.get("data")
        if not traces:
            raise ValueError(f"frame {i} has no 'data'")
        xs, ys = _frame_xy_ordered(traces, seen_ids)
        frames_out.append({"x": xs, "y": ys})
        # NOTE(review): a frame with an explicit null name yields None here,
        # not str(i) — assumed not to occur in flow output; verify if needed.
        times.append(fr.get("name", str(i)))

    # Global bounds across all frames (skipping None placeholders for
    # points that only exist in some frames).
    all_x = [v for fr in frames_out for v in fr["x"] if v is not None]
    all_y = [v for fr in frames_out for v in fr["y"] if v is not None]
    if not all_x or not all_y:
        raise ValueError("no finite x/y values across frames")
    xmin, xmax = min(all_x), max(all_x)
    ymin, ymax = min(all_y), max(all_y)

    # Layout "title" may be {"text": ...}, a bare string, or absent.
    title = ""
    try:
        title = layout["title"]["text"] if isinstance(layout.get("title"), dict) else str(layout.get("title") or "")
    except Exception:
        title = ""
    meta["title"] = title

    return {
        "meta": meta,
        "point_ids": seen_ids,
        "labels": labels,
        "times": times,
        "frames": frames_out,
        "bounds": {"x": [xmin, xmax], "y": [ymin, ymax]},
    }
|
|
|
|
|
|
if __name__ == "__main__":
    import sys
    import traceback

    # Smoke-test the parser over every emitted HTML file. An optional
    # argv[1] overrides the default figs directory (previously hard-coded).
    figs_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("/home/mm/work/dr-sandbox/figs")
    html_files = sorted(figs_dir.glob("*.html"))
    if not html_files:
        print("no .html files in", figs_dir)
        sys.exit(1)

    failures = []
    for p in html_files:
        try:
            out = parse_plotly_run(p)
        except Exception as e:
            # Record and keep going so one bad file doesn't hide the rest.
            failures.append((p.name, str(e)))
            print(f"FAIL {p.name}: {e}")
            traceback.print_exc()
            continue
        m = out["meta"]
        pids = out["point_ids"]
        frames = out["frames"]
        nT = len(frames)
        # Frame-0 coordinate ranges, ignoring None placeholders.
        f0 = frames[0]
        f0x = [v for v in f0["x"] if v is not None]
        f0y = [v for v in f0["y"] if v is not None]
        xr = (min(f0x), max(f0x)) if f0x else (float("nan"), float("nan"))
        yr = (min(f0y), max(f0y)) if f0y else (float("nan"), float("nan"))
        # Every frame must be aligned with the stable id list.
        consistent = all(len(fr["x"]) == len(pids) and len(fr["y"]) == len(pids) for fr in frames)
        present_per_frame = [sum(v is not None for v in fr["x"]) for fr in frames]
        present_rng = (min(present_per_frame), max(present_per_frame))
        print(
            f"OK {m['stem']} |ids|={len(pids)} T={nT} "
            f"present/frame=[{present_rng[0]},{present_rng[1]}] "
            f"f0 x=[{xr[0]:+.3f},{xr[1]:+.3f}] y=[{yr[0]:+.3f},{yr[1]:+.3f}] "
            f"consistent={consistent}"
        )

    print()
    if failures:
        print(f"{len(failures)} failure(s):")
        for name, msg in failures:
            print(f" {name}: {msg}")
    else:
        print(f"all {len(html_files)} files parsed OK")
|