// dr-sandbox/app/web/static/compare.js
//
// 593 lines
// 20 KiB
// JavaScript

// compare.js — side-by-side animated scatter for two embedding runs.
//
// Reads ?a=<stemA>&b=<stemB> from the URL, fetches /api/runs/<stem>/frames.json
// for each, and renders them into two linked three.js panels with a shared
// play/scrub/speed control strip. Linked hover: cursor on a point in one
// panel highlights the same point_id in the other.
import * as THREE from 'three';
// -------- URL / DOM wiring ------------------------------------------------
const params = new URLSearchParams(window.location.search);
const STEM_A = params.get('a') || '';
const STEM_B = params.get('b') || '';
const layout = document.getElementById('compare-layout');
const panelElA = layout.querySelector('.compare-panel[data-slot="a"]');
const panelElB = layout.querySelector('.compare-panel[data-slot="b"]');
const playBtn = document.getElementById('cc-play');
const scrub = document.getElementById('cc-scrub');
const speedSel = document.getElementById('cc-speed');
const syncSel = document.getElementById('cc-sync');
const motionSel = document.getElementById('cc-motion');
const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]');
const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]');
// -------- small helpers ---------------------------------------------------
// Build a soft-edged circular sprite for THREE.Points. A plain texture with
// radial alpha gives anti-aliased round dots without fragment-shader work.
function makeDiskTexture(size = 64) {
const c = document.createElement('canvas');
c.width = c.height = size;
const g = c.getContext('2d');
const r = size / 2;
const grd = g.createRadialGradient(r, r, 0, r, r, r);
grd.addColorStop(0.0, 'rgba(255,255,255,1)');
grd.addColorStop(0.55, 'rgba(255,255,255,1)');
grd.addColorStop(0.85, 'rgba(255,255,255,0.35)');
grd.addColorStop(1.0, 'rgba(255,255,255,0)');
g.fillStyle = grd;
g.fillRect(0, 0, size, size);
const tex = new THREE.CanvasTexture(c);
tex.needsUpdate = true;
return tex;
}
// Sprite texture built once at module load. createPanel passes this same
// texture as the `map` of both its base-point and highlight materials.
const DISK_TEX = makeDiskTexture();
function readVar(name, fallback) {
const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim();
return v || fallback;
}
function readPanelBg(el) {
const v = getComputedStyle(el).getPropertyValue('--picker-panel').trim();
return v || readVar('--panel', '#ffffff');
}
// Panel accent colors keyed by slot. Resolved from CSS vars so they flip
// with the theme via 'themechange'.
function panelAccent(slot) {
return slot === 'a' ? readVar('--accent', '#1f4e5f') : readVar('--warm', '#a77a2c');
}
function highlightColor() {
return readVar('--alarm', '#8a3a2a');
}
// -------- Panel factory ---------------------------------------------------
// Returns { setFrame, setBounds, setHighlight, onHover, resize, dispose, state }
function createPanel({ slotId, panelEl, data }) {
const canvasEl = panelEl.querySelector('[data-role="canvas"]');
const statusEl = panelEl.querySelector('[data-role="status"]');
const scene = new THREE.Scene();
scene.background = new THREE.Color(readPanelBg(panelEl));
const camera = new THREE.OrthographicCamera(-1, 1, 1, -1, -10, 10);
const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: false });
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
canvasEl.appendChild(renderer.domElement);
// Pre-allocate a buffer sized for num_points. Each frame we repack the
// non-null points into the prefix and call setDrawRange(0, count).
const maxN = data.point_ids.length;
const positions = new Float32Array(maxN * 3);
const ids = new Int32Array(maxN); // packed point_id per drawn vertex
const geo = new THREE.BufferGeometry();
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
geo.setDrawRange(0, 0);
const mat = new THREE.PointsMaterial({
size: 6.0,
sizeAttenuation: false,
map: DISK_TEX,
transparent: true,
alphaTest: 0.2,
depthWrite: false,
color: new THREE.Color(panelAccent(slotId)),
});
const points = new THREE.Points(geo, mat);
scene.add(points);
// Highlight overlay — a second 1-vertex Points object drawn on top.
const hiPos = new Float32Array(3);
const hiGeo = new THREE.BufferGeometry();
hiGeo.setAttribute('position', new THREE.BufferAttribute(hiPos, 3));
hiGeo.setDrawRange(0, 0);
const hiMat = new THREE.PointsMaterial({
size: 14.0,
sizeAttenuation: false,
map: DISK_TEX,
transparent: true,
alphaTest: 0.05,
depthWrite: false,
color: new THREE.Color(highlightColor()),
opacity: 0.9,
});
const hiPoints = new THREE.Points(hiGeo, hiMat);
// Ensure highlight renders above base points
hiPoints.renderOrder = 1;
scene.add(hiPoints);
// ---- current-frame packed state (kept in closure for hover picking) ----
// packedX/packedY/packedId of length = current draw count.
let packedN = 0;
const packedX = new Float32Array(maxN);
const packedY = new Float32Array(maxN);
const packedId = new Int32Array(maxN);
// Camera frame rectangle (world coords) currently applied.
const camRect = { xmin: -1, xmax: 1, ymin: -1, ymax: 1 };
function applyCamRect(rect) {
const pad = 0.05;
const dx = rect.xmax - rect.xmin || 1;
const dy = rect.ymax - rect.ymin || 1;
const cx = (rect.xmin + rect.xmax) / 2;
const cy = (rect.ymin + rect.ymax) / 2;
const rx = dx * (0.5 + pad);
const ry = dy * (0.5 + pad);
// Fit-to-larger-axis so points never get squashed when the panel aspect
// doesn't match the bounds aspect. We compute the viewport aspect here
// so the ortho frustum covers the data and then some.
const rect2 = canvasEl.getBoundingClientRect();
const vw = Math.max(1, rect2.width);
const vh = Math.max(1, rect2.height);
const aspect = vw / vh;
const dataAspect = (rx * 2) / (ry * 2);
let halfW, halfH;
if (aspect > dataAspect) {
// viewport wider than data: expand X to fit
halfH = ry;
halfW = ry * aspect;
} else {
halfW = rx;
halfH = rx / aspect;
}
camera.left = cx - halfW;
camera.right = cx + halfW;
camera.top = cy + halfH;
camera.bottom = cy - halfH;
camera.updateProjectionMatrix();
camRect.xmin = rect.xmin;
camRect.xmax = rect.xmax;
camRect.ymin = rect.ymin;
camRect.ymax = rect.ymax;
}
function resize() {
const rect = canvasEl.getBoundingClientRect();
const w = Math.max(1, Math.floor(rect.width));
const h = Math.max(1, Math.floor(rect.height));
renderer.setSize(w, h, false);
// Re-apply the current cam rect so the aspect fit recomputes.
applyCamRect(camRect);
}
function setBounds(b) {
applyCamRect({ xmin: b.x[0], xmax: b.x[1], ymin: b.y[0], ymax: b.y[1] });
}
function applyColorsFromTheme() {
mat.color.set(panelAccent(slotId));
hiMat.color.set(highlightColor());
scene.background = new THREE.Color(readPanelBg(panelEl));
}
// Pack frame `f` into the geometry buffer, skipping null x/y.
function setFrame(f) {
const frame = data.frames[f];
if (!frame) return;
const xs = frame.x, ys = frame.y;
const ptIds = data.point_ids;
let j = 0;
for (let i = 0; i < xs.length; i++) {
const x = xs[i], y = ys[i];
if (x === null || y === null || x === undefined || y === undefined
|| Number.isNaN(x) || Number.isNaN(y)) continue;
positions[j * 3 + 0] = x;
positions[j * 3 + 1] = y;
positions[j * 3 + 2] = 0;
packedX[j] = x;
packedY[j] = y;
packedId[j] = ptIds[i];
j++;
}
packedN = j;
geo.attributes.position.needsUpdate = true;
geo.setDrawRange(0, packedN);
applyHighlightForCurrentFrame();
}
// Pack an interpolated frame. uLocal is a continuous index in [0, T-1].
// Points missing in either adjacent frame are skipped for the duration of
// that transition (no connect-back to the last-known position).
function setFrameInterpolated(uLocal) {
const T = data.frames.length;
if (T === 0) return;
if (uLocal <= 0) return setFrame(0);
if (uLocal >= T - 1) return setFrame(T - 1);
const f0 = Math.floor(uLocal);
const f1 = f0 + 1;
const t = uLocal - f0;
const fr0 = data.frames[f0], fr1 = data.frames[f1];
const x0 = fr0.x, y0 = fr0.y, x1 = fr1.x, y1 = fr1.y;
const ptIds = data.point_ids;
let j = 0;
for (let i = 0; i < x0.length; i++) {
const a = x0[i], b = x1[i], c = y0[i], d = y1[i];
if (a === null || a === undefined || Number.isNaN(a)
|| b === null || b === undefined || Number.isNaN(b)
|| c === null || c === undefined || Number.isNaN(c)
|| d === null || d === undefined || Number.isNaN(d)) continue;
const xi = a + (b - a) * t;
const yi = c + (d - c) * t;
positions[j * 3 + 0] = xi;
positions[j * 3 + 1] = yi;
positions[j * 3 + 2] = 0;
packedX[j] = xi;
packedY[j] = yi;
packedId[j] = ptIds[i];
j++;
}
packedN = j;
geo.attributes.position.needsUpdate = true;
geo.setDrawRange(0, packedN);
applyHighlightForCurrentFrame();
}
// ---- highlight by point_id ---------------------------------------------
let currentHighlightId = -1;
function applyHighlightForCurrentFrame() {
if (currentHighlightId < 0) {
hiGeo.setDrawRange(0, 0);
return;
}
// Linear scan — packedN <= 5000 and this only runs on hover.
for (let i = 0; i < packedN; i++) {
if (packedId[i] === currentHighlightId) {
hiPos[0] = packedX[i];
hiPos[1] = packedY[i];
hiPos[2] = 0.01;
hiGeo.attributes.position.needsUpdate = true;
hiGeo.setDrawRange(0, 1);
return;
}
}
// Not present this frame.
hiGeo.setDrawRange(0, 0);
}
function setHighlight(pointId) {
currentHighlightId = (pointId === null || pointId === undefined) ? -1 : pointId;
applyHighlightForCurrentFrame();
}
// Convert clientX/Y to world coords (ortho, no rotation).
function clientToWorld(clientX, clientY) {
const rect = canvasEl.getBoundingClientRect();
const u = (clientX - rect.left) / Math.max(1, rect.width);
const v = (clientY - rect.top) / Math.max(1, rect.height);
const wx = camera.left + u * (camera.right - camera.left);
const wy = camera.top - v * (camera.top - camera.bottom);
return { wx, wy };
}
// Nearest-point pick with a screen-space radius cap. Returns point_id or -1.
function pickAt(clientX, clientY) {
if (packedN === 0) return -1;
const { wx, wy } = clientToWorld(clientX, clientY);
// Screen-pixel radius -> world radius
const rect = canvasEl.getBoundingClientRect();
const worldPerPx = (camera.right - camera.left) / Math.max(1, rect.width);
const pickPx = 14;
const maxR = pickPx * worldPerPx;
const maxR2 = maxR * maxR;
let bestI = -1;
let bestD2 = Infinity;
for (let i = 0; i < packedN; i++) {
const dx = packedX[i] - wx;
const dy = packedY[i] - wy;
const d2 = dx * dx + dy * dy;
if (d2 < bestD2 && d2 < maxR2) {
bestD2 = d2;
bestI = i;
}
}
return bestI < 0 ? -1 : packedId[bestI];
}
function render() {
renderer.render(scene, camera);
}
function dispose() {
geo.dispose();
hiGeo.dispose();
mat.dispose();
hiMat.dispose();
renderer.dispose();
if (renderer.domElement.parentNode) {
renderer.domElement.parentNode.removeChild(renderer.domElement);
}
}
// Hide the status overlay once we have data.
statusEl.style.display = 'none';
return {
slotId,
canvasEl,
panelEl,
data,
setFrame,
setFrameInterpolated,
setBounds,
setHighlight,
resize,
render,
pickAt,
applyColorsFromTheme,
dispose,
get packedN() { return packedN; },
get camRect() { return camRect; },
applyCamRect,
};
}
// -------- error rendering -------------------------------------------------
function renderError(panelEl, stem, msg) {
const statusEl = panelEl.querySelector('[data-role="status"]');
statusEl.style.display = '';
statusEl.classList.add('is-error');
statusEl.textContent = `could not load ${stem}: ${msg}`;
// Keep header readable
panelEl.querySelector('[data-role="embedder"]').textContent = '—';
panelEl.querySelector('[data-role="generator"]').textContent = '—';
panelEl.querySelector('[data-role="params"]').textContent = '(error)';
}
// Run-metadata fields shown in each panel header, in display order.
// `key` is the property read from the run's `meta`; `prefix` is the short
// label rendered in front of the value (e.g. num_points -> "N<value>").
// markParamDiffs compares the two runs on these same keys.
const PARAM_FIELDS = [
{ key: 'num_points', prefix: 'N' },
{ key: 'num_timesteps', prefix: 'T' },
{ key: 'jitter_scale', prefix: 'J' },
{ key: 'seed', prefix: 's' },
];
function renderHeader(panelEl, data) {
const m = data.meta || {};
panelEl.querySelector('[data-role="embedder"]').textContent = m.embedder || '—';
panelEl.querySelector('[data-role="generator"]').textContent = m.generator || '—';
const host = panelEl.querySelector('[data-role="params"]');
host.textContent = '';
PARAM_FIELDS.forEach(({ key, prefix }, i) => {
if (i > 0) host.appendChild(document.createTextNode(' / '));
const span = document.createElement('span');
span.className = 'param';
span.dataset.key = key;
span.textContent = `${prefix}${m[key] ?? '?'}`;
host.appendChild(span);
});
}
// Toggle .diff on each param span where the two panels disagree.
function markParamDiffs(metaA, metaB) {
if (!metaA || !metaB) return;
for (const { key } of PARAM_FIELDS) {
const differs = metaA[key] !== metaB[key];
for (const panelEl of [panelElA, panelElB]) {
const span = panelEl.querySelector(`[data-role="params"] .param[data-key="${key}"]`);
if (span) span.classList.toggle('diff', differs);
}
}
}
// -------- main ------------------------------------------------------------
async function fetchFrames(stem) {
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`);
if (!res.ok) {
throw new Error(`${res.status} ${res.statusText}`);
}
return res.json();
}
async function main() {
if (!STEM_A || !STEM_B) {
renderError(panelElA, STEM_A || '(missing)', 'no stem in ?a=');
renderError(panelElB, STEM_B || '(missing)', 'no stem in ?b=');
return;
}
// Fetch in parallel; each panel's failure is independent.
const [resA, resB] = await Promise.allSettled([
fetchFrames(STEM_A),
fetchFrames(STEM_B),
]);
const panels = { a: null, b: null };
if (resA.status === 'fulfilled') {
renderHeader(panelElA, resA.value);
panels.a = createPanel({ slotId: 'a', panelEl: panelElA, data: resA.value });
} else {
renderError(panelElA, STEM_A, resA.reason?.message || String(resA.reason));
}
if (resB.status === 'fulfilled') {
renderHeader(panelElB, resB.value);
panels.b = createPanel({ slotId: 'b', panelEl: panelElB, data: resB.value });
} else {
renderError(panelElB, STEM_B, resB.reason?.message || String(resB.reason));
}
if (panels.a && panels.b) markParamDiffs(panels.a.data.meta, panels.b.data.meta);
// Nothing loaded — no animation loop to start.
if (!panels.a && !panels.b) return;
// Initial bounds + first frame for each panel.
for (const p of Object.values(panels)) {
if (!p) continue;
p.setBounds(p.data.bounds);
p.resize();
p.setFrame(0);
}
// ---- time mapping -----------------------------------------------------
// Scrubber is 0..1000. Each panel picks round(u * (T-1)) as its frame idx.
const SCRUB_MAX = 1000;
function framesOf(p) { return p ? p.data.frames.length : 0; }
function timeLabelFor(p, u) {
if (!p) return '—';
const T = framesOf(p);
if (T <= 0) return '—';
const idx = Math.max(0, Math.min(T - 1, Math.round(u * (T - 1))));
return p.data.times?.[idx] ?? String(idx);
}
function applyU(u) {
u = Math.max(0, Math.min(1, u));
const smooth = motionSel.value === 'smooth';
for (const p of Object.values(panels)) {
if (!p) continue;
const T = framesOf(p);
if (T <= 0) continue;
const uLocal = u * (T - 1);
if (smooth) {
p.setFrameInterpolated(uLocal);
} else {
const idx = Math.max(0, Math.min(T - 1, Math.round(uLocal)));
p.setFrame(idx);
}
}
timeAEl.textContent = timeLabelFor(panels.a, u);
timeBEl.textContent = timeLabelFor(panels.b, u);
}
// ---- axes sync mode ---------------------------------------------------
function applySync() {
const mode = syncSel.value;
if (mode === 'locked' && panels.a && panels.b) {
const ba = panels.a.data.bounds, bb = panels.b.data.bounds;
const union = {
x: [Math.min(ba.x[0], bb.x[0]), Math.max(ba.x[1], bb.x[1])],
y: [Math.min(ba.y[0], bb.y[0]), Math.max(ba.y[1], bb.y[1])],
};
panels.a.setBounds(union);
panels.b.setBounds(union);
} else {
if (panels.a) panels.a.setBounds(panels.a.data.bounds);
if (panels.b) panels.b.setBounds(panels.b.data.bounds);
}
}
syncSel.addEventListener('change', applySync);
applySync();
// ---- play loop --------------------------------------------------------
// Base step: 400ms per frame at 1x, divided by speed. The loop advances
// the scrub by (1 / maxT) per step so both panels traverse their whole
// timeline in the same wall-clock duration when T differs.
let playing = false;
let lastTs = 0;
function maxT() {
const ta = framesOf(panels.a);
const tb = framesOf(panels.b);
return Math.max(ta, tb, 2);
}
function baseMsPerFrame() { return 1600 / parseFloat(speedSel.value || '1'); }
function tick(ts) {
requestAnimationFrame(tick);
for (const p of Object.values(panels)) p?.render();
if (!playing) { lastTs = ts; return; }
if (!lastTs) lastTs = ts;
const dt = ts - lastTs;
const perFrame = baseMsPerFrame();
const T = maxT();
const du = dt / (perFrame * (T - 1));
if (du > 0) {
let u = parseFloat(scrub.value) / SCRUB_MAX + du;
if (u >= 1) u -= Math.floor(u); // wrap 0..1
scrub.value = String(Math.round(u * SCRUB_MAX));
applyU(u);
lastTs = ts;
}
}
requestAnimationFrame(tick);
function setPlaying(v) {
playing = v;
lastTs = 0;
playBtn.textContent = playing ? '▮▮' : '▶';
playBtn.setAttribute('aria-label', playing ? 'pause' : 'play');
}
playBtn.addEventListener('click', () => setPlaying(!playing));
scrub.addEventListener('input', () => {
applyU(parseFloat(scrub.value) / SCRUB_MAX);
});
speedSel.addEventListener('change', () => { lastTs = 0; });
motionSel.addEventListener('change', () => {
applyU(parseFloat(scrub.value) / SCRUB_MAX);
});
// ---- linked hover -----------------------------------------------------
function wireHover(pA, pB) {
if (!pA) return;
const el = pA.canvasEl;
el.addEventListener('mousemove', (ev) => {
const id = pA.pickAt(ev.clientX, ev.clientY);
pA.setHighlight(id >= 0 ? id : null);
if (pB) pB.setHighlight(id >= 0 ? id : null);
});
el.addEventListener('mouseleave', () => {
pA.setHighlight(null);
if (pB) pB.setHighlight(null);
});
}
wireHover(panels.a, panels.b);
wireHover(panels.b, panels.a);
// ---- resize + theme ---------------------------------------------------
const ro = new ResizeObserver(() => {
for (const p of Object.values(panels)) p?.resize();
});
if (panels.a) ro.observe(panels.a.canvasEl);
if (panels.b) ro.observe(panels.b.canvasEl);
document.addEventListener('themechange', () => {
for (const p of Object.values(panels)) p?.applyColorsFromTheme();
});
// Initialise the label + scrub at 0.
applyU(0);
}
main().catch((err) => {
console.error(err);
renderError(panelElA, STEM_A, 'init failed: ' + err.message);
renderError(panelElB, STEM_B, 'init failed: ' + err.message);
});