diff --git a/app/web/main.py b/app/web/main.py index 9b1c092..21e5861 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -811,12 +811,48 @@ _STEM_RE = re.compile( r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$" ) +# Map short generator name ("make_blobs") to its DATASET_META entry. +# swiss_roll and swiss_roll_hole collide on path; first wins (plain variant). +_GEN_TO_META: Dict[str, Dict[str, Any]] = {} +for _m in DATASET_META.values(): + _GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m) + + +def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]: + """Attach per-point class/continuous labels by regenerating the dataset + with the same (generator, n_samples, kwargs). The stem's `seed` drives + jitter — NOT generator — so we always use random_state=0 to match the + flow's _DEFAULT_GENERATOR_KWARGS. Jitter-added points (id >= num_points) + get None so the client renders them as black.""" + meta = _GEN_TO_META.get(d["meta"].get("generator") or "") + if not meta: + return d + try: + mod_path, cls_name = meta["path"].rsplit(".", 1) + fn = getattr(importlib.import_module(mod_path), cls_name) + N = int(d["meta"]["num_points"]) + _, gen_labels = fn(n_samples=N, random_state=0, **meta["kwargs"]) + out_labels: List[Optional[float]] = [] + for pid in d["point_ids"]: + if isinstance(pid, int) and 0 <= pid < N: + v = gen_labels[pid] + out_labels.append(float(v) if hasattr(v, "item") or isinstance(v, (int, float)) else None) + else: + out_labels.append(None) + d["labels"] = out_labels + d["label_kind"] = meta["kind"] + except Exception: + pass + return d + @lru_cache(maxsize=32) def _cached_frames(stem: str) -> str: """Parse .html and return the frames dict as a JSON string.""" html = FIGS_DIR / f"{stem}.html" - return json.dumps(parse_plotly_run(html), separators=(",", ":")) + d = parse_plotly_run(html) + d = _enrich_with_labels(d) + return json.dumps(d, separators=(",", ":")) @app.get("/api/runs/{stem}/frames.json") @@ -830,7 +866,11 @@ async def run_frames(stem: str) -> Response: payload = _cached_frames(stem) except Exception as e: raise HTTPException(500, f"parse failed: {e}") - return Response(content=payload, media_type="application/json") + return Response( + content=payload, + media_type="application/json", + headers={"Cache-Control": "no-cache"}, + ) @app.get("/compare", response_class=HTMLResponse) diff --git a/app/web/static/compare.js b/app/web/static/compare.js index f86af7d..b5fdbdf 100644 --- a/app/web/static/compare.js +++ b/app/web/static/compare.js @@ -22,6 +22,7 @@ const scrub = document.getElementById('cc-scrub'); const speedSel = document.getElementById('cc-speed'); const syncSel = document.getElementById('cc-sync'); const motionSel = document.getElementById('cc-motion'); +const colorSel = document.getElementById('cc-color'); const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]'); const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]'); @@ -48,6 +49,84 @@ function makeDiskTexture(size = 64) { const DISK_TEX = makeDiskTexture(); +// Continuous ramp (blue → orange), matching dataset-picker.js exactly. +function rampContinuous(t, out) { + const hue = (1 - t) * 215 + t * 28; + const sat = 0.62; + const lit = 0.50 + (t - 0.5) * 0.08; + return (out || new THREE.Color()).setHSL(hue / 360, sat, lit); +} + +// 8-color categorical palette — same hex list as dataset-picker.js. +const CATEGORICAL_HEX = [ + '#1f4e5f', '#c97b3f', '#8b5a9f', '#5a8560', + '#c74a5e', '#6b7d8f', '#b89f51', '#4a6fa5', +]; +const CATEGORICAL = CATEGORICAL_HEX.map(h => new THREE.Color(h)); + +// Precompute per-point RGB, indexed by position in data.point_ids. +// If the server attached labels + label_kind, color by that (ramp for +// continuous, palette for categorical) to match the dataset picker. Points +// with null labels (jitter-added, id >= num_points) stay (0,0,0) = black. +// Falls back to a frame-0-present ordinal ramp when no labels are present. +function buildIdColorsRGB(data) { + const n = data.point_ids.length; + const rgb = new Float32Array(n * 3); + const labels = data.labels || []; + const kind = data.label_kind || null; + const hasRealLabels = kind && labels.some(v => v != null && v !== ''); + + if (hasRealLabels) { + const tmp = new THREE.Color(); + if (kind === 'categorical') { + for (let i = 0; i < n; i++) { + const v = labels[i]; + if (v == null) continue; + const c = CATEGORICAL[((v | 0) % CATEGORICAL.length + CATEGORICAL.length) % CATEGORICAL.length]; + rgb[i * 3 + 0] = c.r; + rgb[i * 3 + 1] = c.g; + rgb[i * 3 + 2] = c.b; + } + } else { + let lo = Infinity, hi = -Infinity; + for (const v of labels) { if (v == null) continue; if (v < lo) lo = v; if (v > hi) hi = v; } + const range = (hi - lo) || 1; + for (let i = 0; i < n; i++) { + const v = labels[i]; + if (v == null) continue; + rampContinuous((v - lo) / range, tmp); + rgb[i * 3 + 0] = tmp.r; + rgb[i * 3 + 1] = tmp.g; + rgb[i * 3 + 2] = tmp.b; + } + } + return rgb; + } + + // Fallback: rainbow-by-ordinal over frame-0-present points. + const frame0 = data.frames[0]; + if (!frame0) return rgb; + const originalPositions = []; + for (let i = 0; i < n; i++) { + if (frame0.x[i] != null && frame0.y[i] != null + && !Number.isNaN(frame0.x[i]) && !Number.isNaN(frame0.y[i])) { + originalPositions.push(i); + } + } + originalPositions.sort((a, b) => data.point_ids[a] - data.point_ids[b]); + const nOrig = originalPositions.length; + const tmp = new THREE.Color(); + for (let k = 0; k < nOrig; k++) { + const t = nOrig > 1 ? k / (nOrig - 1) : 0.5; + rampContinuous(t, tmp); + const idx = originalPositions[k]; + rgb[idx * 3 + 0] = tmp.r; + rgb[idx * 3 + 1] = tmp.g; + rgb[idx * 3 + 2] = tmp.b; + } + return rgb; +} + function readVar(name, fallback) { const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); return v || fallback; @@ -87,13 +166,16 @@ function createPanel({ slotId, panelEl, data }) { // non-null points into the prefix and call setDrawRange(0, count). const maxN = data.point_ids.length; const positions = new Float32Array(maxN * 3); + const colors = new Float32Array(maxN * 3); const ids = new Int32Array(maxN); // packed point_id per drawn vertex + const idColorRGB = buildIdColorsRGB(data); const geo = new THREE.BufferGeometry(); geo.setAttribute('position', new THREE.BufferAttribute(positions, 3)); + geo.setAttribute('color', new THREE.BufferAttribute(colors, 3)); geo.setDrawRange(0, 0); - const mat = new THREE.PointsMaterial({ + const matMono = new THREE.PointsMaterial({ size: 6.0, sizeAttenuation: false, map: DISK_TEX, @@ -102,6 +184,19 @@ function createPanel({ slotId, panelEl, data }) { depthWrite: false, color: new THREE.Color(panelAccent(slotId)), }); + const matRainbow = new THREE.PointsMaterial({ + size: 6.0, + sizeAttenuation: false, + map: DISK_TEX, + // Opaque + alpha cutoff rather than alpha blending: with a ramp, many + // overlapping translucent points average out to the middle of the ramp + // (blue + orange → green), washing out the whole cloud. Hard edges keep + // each point's true color visible. + transparent: false, + alphaTest: 0.5, + vertexColors: true, + }); + let mat = matMono; const points = new THREE.Points(geo, mat); scene.add(points); @@ -185,11 +280,26 @@ function createPanel({ slotId, panelEl, data }) { } function applyColorsFromTheme() { - mat.color.set(panelAccent(slotId)); + matMono.color.set(panelAccent(slotId)); hiMat.color.set(highlightColor()); scene.background = new THREE.Color(readPanelBg(panelEl)); } + function setColorMode(mode) { + const next = mode === 'mono' ? matMono : matRainbow; + if (points.material !== next) { + points.material = next; + mat = next; + } + } + + // Copy the precomputed per-id RGB into the packed position `j`. + function writePackedColor(j, i) { + colors[j * 3 + 0] = idColorRGB[i * 3 + 0]; + colors[j * 3 + 1] = idColorRGB[i * 3 + 1]; + colors[j * 3 + 2] = idColorRGB[i * 3 + 2]; + } + // Pack frame `f` into the geometry buffer, skipping null x/y. function setFrame(f) { const frame = data.frames[f]; @@ -204,6 +314,7 @@ function createPanel({ slotId, panelEl, data }) { positions[j * 3 + 0] = x; positions[j * 3 + 1] = y; positions[j * 3 + 2] = 0; + writePackedColor(j, i); packedX[j] = x; packedY[j] = y; packedId[j] = ptIds[i]; @@ -211,6 +322,7 @@ function createPanel({ slotId, panelEl, data }) { } packedN = j; geo.attributes.position.needsUpdate = true; + geo.attributes.color.needsUpdate = true; geo.setDrawRange(0, packedN); applyHighlightForCurrentFrame(); } @@ -241,6 +353,7 @@ function createPanel({ slotId, panelEl, data }) { positions[j * 3 + 0] = xi; positions[j * 3 + 1] = yi; positions[j * 3 + 2] = 0; + writePackedColor(j, i); packedX[j] = xi; packedY[j] = yi; packedId[j] = ptIds[i]; @@ -248,6 +361,7 @@ function createPanel({ slotId, panelEl, data }) { } packedN = j; geo.attributes.position.needsUpdate = true; + geo.attributes.color.needsUpdate = true; geo.setDrawRange(0, packedN); applyHighlightForCurrentFrame(); } @@ -321,7 +435,8 @@ function createPanel({ slotId, panelEl, data }) { function dispose() { geo.dispose(); hiGeo.dispose(); - mat.dispose(); + matMono.dispose(); + matRainbow.dispose(); hiMat.dispose(); renderer.dispose(); if (renderer.domElement.parentNode) { @@ -339,6 +454,7 @@ function createPanel({ slotId, panelEl, data }) { data, setFrame, setFrameInterpolated, + setColorMode, setBounds, setHighlight, resize, @@ -410,7 +526,7 @@ function markParamDiffs(metaA, metaB) { // -------- main ------------------------------------------------------------ async function fetchFrames(stem) { - const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`); + const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`, { cache: 'no-store' }); if (!res.ok) { throw new Error(`${res.status} ${res.statusText}`); } @@ -571,6 +687,13 @@ async function main() { applyU(parseFloat(scrub.value) / SCRUB_MAX); }); + function applyColorMode() { + const mode = colorSel.value; + for (const p of Object.values(panels)) p?.setColorMode(mode); + } + colorSel.addEventListener('change', applyColorMode); + applyColorMode(); + // ---- linked hover ----------------------------------------------------- function wireHover(pA, pB) { if (!pA) return; diff --git a/app/web/static/style.css b/app/web/static/style.css index e8f559c..21acdbc 100644 --- a/app/web/static/style.css +++ b/app/web/static/style.css @@ -1635,7 +1635,8 @@ button.submit:disabled { background: var(--faint); border-color: var(--faint); c .compare-controls .cc-speed-wrap, .compare-controls .cc-sync-wrap, -.compare-controls .cc-motion-wrap { +.compare-controls .cc-motion-wrap, +.compare-controls .cc-color-wrap { display: inline-flex; align-items: center; gap: 0.35rem; diff --git a/app/web/templates/compare.html b/app/web/templates/compare.html index 3c098b9..d43b485 100644 --- a/app/web/templates/compare.html +++ b/app/web/templates/compare.html @@ -4,7 +4,7 @@ embedding notebook · compare - + - + diff --git a/app/web/templates/index.html b/app/web/templates/index.html index d6d2b9c..967b0e9 100644 --- a/app/web/templates/index.html +++ b/app/web/templates/index.html @@ -4,7 +4,7 @@ embedding notebook - +