compare: color points by their original-dataset label (mono|original toggle)

Server enrichment regenerates the dataset deterministically (random_state=0,
matching the flow's _DEFAULT_GENERATOR_KWARGS — the stem's seed drives jitter,
not generation) and attaches per-point labels + label_kind to frames.json.

Client picks the dataset-picker's scheme: continuous ramp for s_curve/swiss_roll,
8-color categorical palette for blobs/gaussian_quantiles/classification. Jitter-
added points (id >= num_points) render black. Rainbow material is opaque with
alpha cutoff so overlapping points don't blend to the ramp midpoint.

swiss_roll and swiss_roll_hole collide on generator_path; the plain variant
wins for now (kwargs aren't preserved through the flow's metrics.json, so the
two can't be told apart).

Bumped Cache-Control on the frames endpoint so browsers don't cache stale
pre-enrichment payloads.
This commit is contained in:
Michael Pilosov 2026-04-22 15:29:03 -06:00
parent d3f5088233
commit 9277229024
5 changed files with 182 additions and 10 deletions

View File

@ -811,12 +811,48 @@ _STEM_RE = re.compile(
r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$" r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
) )
# Index DATASET_META entries by short generator name, i.e. the last dotted
# component of each entry's "path" (e.g. "make_blobs").
# swiss_roll and swiss_roll_hole share a path; setdefault keeps whichever entry
# DATASET_META yields first (the plain variant) and ignores the later one.
_GEN_TO_META: Dict[str, Dict[str, Any]] = {}
for _m in DATASET_META.values():
    _GEN_TO_META.setdefault(_m["path"].rpartition(".")[2], _m)
def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
    """Attach per-point labels to a parsed frames dict (best-effort).

    The dataset is regenerated with the same (generator, n_samples, kwargs)
    as the run. The stem's `seed` drives jitter, NOT the generator, so
    generation always uses random_state=0 to match the flow's
    _DEFAULT_GENERATOR_KWARGS.

    Args:
        d: Parsed frames dict. Reads d["meta"]["generator"],
            d["meta"]["num_points"] and d["point_ids"].

    Returns:
        The same dict object. On success it gains:
          * "labels": one entry per element of d["point_ids"]; a float for
            ids in [0, num_points), None for jitter-added points
            (id >= num_points), unusable ids, and non-finite label values.
          * "label_kind": the DATASET_META entry's "kind".
        If the generator is unknown or regeneration fails, `d` is returned
        without either key (never with only one of them).
    """
    # Function-scope imports: only this block needs them.
    import logging
    import math

    run_meta = d.get("meta") or {}
    meta = _GEN_TO_META.get(run_meta.get("generator") or "")
    if not meta:
        return d
    try:
        # Resolve everything that can fail before touching `d`, so a partial
        # failure cannot leave "labels" attached without "label_kind".
        kind = meta["kind"]
        mod_path, fn_name = meta["path"].rsplit(".", 1)
        fn = getattr(importlib.import_module(mod_path), fn_name)
        n = int(run_meta["num_points"])
        _, gen_labels = fn(n_samples=n, random_state=0, **meta["kwargs"])
        out_labels: List[Optional[float]] = []
        for pid in d["point_ids"]:
            label: Optional[float] = None
            if isinstance(pid, int) and 0 <= pid < n:
                v = gen_labels[pid]
                if hasattr(v, "item") or isinstance(v, (int, float)):
                    as_float = float(v)
                    # json.dumps would emit NaN/Infinity, which is not valid
                    # JSON and breaks the client's parser; send null instead.
                    if math.isfinite(as_float):
                        label = as_float
            out_labels.append(label)
    except Exception:
        # Enrichment is optional: keep serving the un-enriched frames, but
        # leave a trace instead of swallowing the failure silently.
        logging.getLogger(__name__).warning(
            "label enrichment failed for generator %r",
            run_meta.get("generator"),
            exc_info=True,
        )
        return d
    d["labels"] = out_labels
    d["label_kind"] = kind
    return d
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _cached_frames(stem: str) -> str: def _cached_frames(stem: str) -> str:
"""Parse <stem>.html and return the frames dict as a JSON string.""" """Parse <stem>.html and return the frames dict as a JSON string."""
html = FIGS_DIR / f"{stem}.html" html = FIGS_DIR / f"{stem}.html"
return json.dumps(parse_plotly_run(html), separators=(",", ":")) d = parse_plotly_run(html)
d = _enrich_with_labels(d)
return json.dumps(d, separators=(",", ":"))
@app.get("/api/runs/{stem}/frames.json") @app.get("/api/runs/{stem}/frames.json")
@ -830,7 +866,11 @@ async def run_frames(stem: str) -> Response:
payload = _cached_frames(stem) payload = _cached_frames(stem)
except Exception as e: except Exception as e:
raise HTTPException(500, f"parse failed: {e}") raise HTTPException(500, f"parse failed: {e}")
return Response(content=payload, media_type="application/json") return Response(
content=payload,
media_type="application/json",
headers={"Cache-Control": "no-cache"},
)
@app.get("/compare", response_class=HTMLResponse) @app.get("/compare", response_class=HTMLResponse)

View File

@ -22,6 +22,7 @@ const scrub = document.getElementById('cc-scrub');
const speedSel = document.getElementById('cc-speed'); const speedSel = document.getElementById('cc-speed');
const syncSel = document.getElementById('cc-sync'); const syncSel = document.getElementById('cc-sync');
const motionSel = document.getElementById('cc-motion'); const motionSel = document.getElementById('cc-motion');
const colorSel = document.getElementById('cc-color');
const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]'); const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]');
const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]'); const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]');
@ -48,6 +49,84 @@ function makeDiskTexture(size = 64) {
const DISK_TEX = makeDiskTexture(); const DISK_TEX = makeDiskTexture();
// Continuous ramp (blue → orange); per the original note this mirrors
// dataset-picker.js exactly, so keep the constants in sync with that file.
// `t` is the ramp position; hue goes from 215° at t=0 to 28° at t=1, with
// fixed saturation 0.62 and lightness 0.46 → 0.54. Writes into `out` when
// given (lets callers reuse one scratch Color), otherwise allocates a new one.
function rampContinuous(t, out) {
  const target = out || new THREE.Color();
  const hueDeg = (1 - t) * 215 + t * 28;
  const lightness = 0.50 + (t - 0.5) * 0.08;
  return target.setHSL(hueDeg / 360, 0.62, lightness);
}
// 8-color categorical palette — same hex list as dataset-picker.js.
// Kept as hex strings and converted once to THREE.Color instances below.
const CATEGORICAL_HEX = [
'#1f4e5f', '#c97b3f', '#8b5a9f', '#5a8560',
'#c74a5e', '#6b7d8f', '#b89f51', '#4a6fa5',
];
// Parsed palette; buildIdColorsRGB indexes it by integer label modulo its
// length, so labels beyond 7 wrap around and reuse colors.
const CATEGORICAL = CATEGORICAL_HEX.map(h => new THREE.Color(h));
// Build the per-point RGB table (3 floats per point), indexed by position in
// data.point_ids.
//  - With server-provided labels + label_kind: 'categorical' uses the palette
//    (integer label modulo palette size); any other kind uses the continuous
//    ramp, normalized over the min/max of the non-null labels.
//  - Points with a null label (jitter-added, id >= num_points) keep the
//    zero-initialized (0,0,0) = black.
//  - Without usable labels: ramp by ordinal over the points present in
//    frame 0, ordered by point id.
function buildIdColorsRGB(data) {
  const count = data.point_ids.length;
  const rgb = new Float32Array(count * 3);
  const scratch = new THREE.Color();
  const store = (slot, color) => {
    const base = slot * 3;
    rgb[base] = color.r;
    rgb[base + 1] = color.g;
    rgb[base + 2] = color.b;
  };

  const labels = data.labels || [];
  const kind = data.label_kind || null;
  const labelled = kind && labels.some((v) => v != null && v !== '');

  if (labelled && kind === 'categorical') {
    const size = CATEGORICAL.length;
    for (let i = 0; i < count; i++) {
      const v = labels[i];
      if (v == null) continue;
      // Truncate to int, then wrap into [0, size) even for negative labels.
      const k = (v | 0) % size;
      store(i, CATEGORICAL[k < 0 ? k + size : k]);
    }
    return rgb;
  }

  if (labelled) {
    let lo = Infinity;
    let hi = -Infinity;
    for (const v of labels) {
      if (v == null) continue;
      if (v < lo) lo = v;
      if (v > hi) hi = v;
    }
    // A constant label set has zero span; fall back to 1 to avoid 0/0.
    const span = (hi - lo) || 1;
    for (let i = 0; i < count; i++) {
      const v = labels[i];
      if (v == null) continue;
      store(i, rampContinuous((v - lo) / span, scratch));
    }
    return rgb;
  }

  // Fallback: rainbow-by-ordinal over frame-0-present points.
  const first = data.frames[0];
  if (!first) return rgb;
  const present = [];
  for (let i = 0; i < count; i++) {
    const x = first.x[i];
    const y = first.y[i];
    if (x == null || y == null) continue;
    if (Number.isNaN(x) || Number.isNaN(y)) continue;
    present.push(i);
  }
  present.sort((a, b) => data.point_ids[a] - data.point_ids[b]);
  const total = present.length;
  present.forEach((slot, rank) => {
    // A single point sits at the middle of the ramp.
    const t = total > 1 ? rank / (total - 1) : 0.5;
    store(slot, rampContinuous(t, scratch));
  });
  return rgb;
}
function readVar(name, fallback) { function readVar(name, fallback) {
const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim();
return v || fallback; return v || fallback;
@ -87,13 +166,16 @@ function createPanel({ slotId, panelEl, data }) {
// non-null points into the prefix and call setDrawRange(0, count). // non-null points into the prefix and call setDrawRange(0, count).
const maxN = data.point_ids.length; const maxN = data.point_ids.length;
const positions = new Float32Array(maxN * 3); const positions = new Float32Array(maxN * 3);
const colors = new Float32Array(maxN * 3);
const ids = new Int32Array(maxN); // packed point_id per drawn vertex const ids = new Int32Array(maxN); // packed point_id per drawn vertex
const idColorRGB = buildIdColorsRGB(data);
const geo = new THREE.BufferGeometry(); const geo = new THREE.BufferGeometry();
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3)); geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
geo.setAttribute('color', new THREE.BufferAttribute(colors, 3));
geo.setDrawRange(0, 0); geo.setDrawRange(0, 0);
const mat = new THREE.PointsMaterial({ const matMono = new THREE.PointsMaterial({
size: 6.0, size: 6.0,
sizeAttenuation: false, sizeAttenuation: false,
map: DISK_TEX, map: DISK_TEX,
@ -102,6 +184,19 @@ function createPanel({ slotId, panelEl, data }) {
depthWrite: false, depthWrite: false,
color: new THREE.Color(panelAccent(slotId)), color: new THREE.Color(panelAccent(slotId)),
}); });
const matRainbow = new THREE.PointsMaterial({
size: 6.0,
sizeAttenuation: false,
map: DISK_TEX,
// Opaque + alpha cutoff rather than alpha blending: with a ramp, many
// overlapping translucent points average out to the middle of the ramp
// (blue + orange → green), washing out the whole cloud. Hard edges keep
// each point's true color visible.
transparent: false,
alphaTest: 0.5,
vertexColors: true,
});
let mat = matMono;
const points = new THREE.Points(geo, mat); const points = new THREE.Points(geo, mat);
scene.add(points); scene.add(points);
@ -185,11 +280,26 @@ function createPanel({ slotId, panelEl, data }) {
} }
function applyColorsFromTheme() { function applyColorsFromTheme() {
mat.color.set(panelAccent(slotId)); matMono.color.set(panelAccent(slotId));
hiMat.color.set(highlightColor()); hiMat.color.set(highlightColor());
scene.background = new THREE.Color(readPanelBg(panelEl)); scene.background = new THREE.Color(readPanelBg(panelEl));
} }
// Switch the point cloud between the single-accent material ('mono') and the
// per-vertex-colored material (any other mode). No-op when the requested
// material is already active.
function setColorMode(mode) {
  const target = (mode === 'mono') ? matMono : matRainbow;
  if (points.material === target) return;
  points.material = target;
  mat = target;
}
// Copy point i's precomputed RGB triple (from idColorRGB) into packed vertex
// slot j of the geometry's color buffer.
function writePackedColor(j, i) {
  const dst = j * 3;
  const src = i * 3;
  for (let c = 0; c < 3; c++) {
    colors[dst + c] = idColorRGB[src + c];
  }
}
// Pack frame `f` into the geometry buffer, skipping null x/y. // Pack frame `f` into the geometry buffer, skipping null x/y.
function setFrame(f) { function setFrame(f) {
const frame = data.frames[f]; const frame = data.frames[f];
@ -204,6 +314,7 @@ function createPanel({ slotId, panelEl, data }) {
positions[j * 3 + 0] = x; positions[j * 3 + 0] = x;
positions[j * 3 + 1] = y; positions[j * 3 + 1] = y;
positions[j * 3 + 2] = 0; positions[j * 3 + 2] = 0;
writePackedColor(j, i);
packedX[j] = x; packedX[j] = x;
packedY[j] = y; packedY[j] = y;
packedId[j] = ptIds[i]; packedId[j] = ptIds[i];
@ -211,6 +322,7 @@ function createPanel({ slotId, panelEl, data }) {
} }
packedN = j; packedN = j;
geo.attributes.position.needsUpdate = true; geo.attributes.position.needsUpdate = true;
geo.attributes.color.needsUpdate = true;
geo.setDrawRange(0, packedN); geo.setDrawRange(0, packedN);
applyHighlightForCurrentFrame(); applyHighlightForCurrentFrame();
} }
@ -241,6 +353,7 @@ function createPanel({ slotId, panelEl, data }) {
positions[j * 3 + 0] = xi; positions[j * 3 + 0] = xi;
positions[j * 3 + 1] = yi; positions[j * 3 + 1] = yi;
positions[j * 3 + 2] = 0; positions[j * 3 + 2] = 0;
writePackedColor(j, i);
packedX[j] = xi; packedX[j] = xi;
packedY[j] = yi; packedY[j] = yi;
packedId[j] = ptIds[i]; packedId[j] = ptIds[i];
@ -248,6 +361,7 @@ function createPanel({ slotId, panelEl, data }) {
} }
packedN = j; packedN = j;
geo.attributes.position.needsUpdate = true; geo.attributes.position.needsUpdate = true;
geo.attributes.color.needsUpdate = true;
geo.setDrawRange(0, packedN); geo.setDrawRange(0, packedN);
applyHighlightForCurrentFrame(); applyHighlightForCurrentFrame();
} }
@ -321,7 +435,8 @@ function createPanel({ slotId, panelEl, data }) {
function dispose() { function dispose() {
geo.dispose(); geo.dispose();
hiGeo.dispose(); hiGeo.dispose();
mat.dispose(); matMono.dispose();
matRainbow.dispose();
hiMat.dispose(); hiMat.dispose();
renderer.dispose(); renderer.dispose();
if (renderer.domElement.parentNode) { if (renderer.domElement.parentNode) {
@ -339,6 +454,7 @@ function createPanel({ slotId, panelEl, data }) {
data, data,
setFrame, setFrame,
setFrameInterpolated, setFrameInterpolated,
setColorMode,
setBounds, setBounds,
setHighlight, setHighlight,
resize, resize,
@ -410,7 +526,7 @@ function markParamDiffs(metaA, metaB) {
// -------- main ------------------------------------------------------------ // -------- main ------------------------------------------------------------
async function fetchFrames(stem) { async function fetchFrames(stem) {
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`); const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`, { cache: 'no-store' });
if (!res.ok) { if (!res.ok) {
throw new Error(`${res.status} ${res.statusText}`); throw new Error(`${res.status} ${res.statusText}`);
} }
@ -571,6 +687,13 @@ async function main() {
applyU(parseFloat(scrub.value) / SCRUB_MAX); applyU(parseFloat(scrub.value) / SCRUB_MAX);
}); });
// Push the color <select>'s current value to every existing panel; empty
// panel slots (null/undefined) are skipped.
function applyColorMode() {
  const mode = colorSel.value;
  Object.values(panels).forEach((panel) => {
    panel?.setColorMode(mode);
  });
}
colorSel.addEventListener('change', applyColorMode);
applyColorMode();
// ---- linked hover ----------------------------------------------------- // ---- linked hover -----------------------------------------------------
function wireHover(pA, pB) { function wireHover(pA, pB) {
if (!pA) return; if (!pA) return;

View File

@ -1635,7 +1635,8 @@ button.submit:disabled { background: var(--faint); border-color: var(--faint); c
.compare-controls .cc-speed-wrap, .compare-controls .cc-speed-wrap,
.compare-controls .cc-sync-wrap, .compare-controls .cc-sync-wrap,
.compare-controls .cc-motion-wrap { .compare-controls .cc-motion-wrap,
.compare-controls .cc-color-wrap {
display: inline-flex; display: inline-flex;
align-items: center; align-items: center;
gap: 0.35rem; gap: 0.35rem;

View File

@ -4,7 +4,7 @@
<meta charset="utf-8" /> <meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" /> <meta name="viewport" content="width=device-width,initial-scale=1" />
<title>embedding notebook &middot; compare</title> <title>embedding notebook &middot; compare</title>
<link rel="stylesheet" href="/static/style.css?v=25" /> <link rel="stylesheet" href="/static/style.css?v=26" />
<script type="importmap"> <script type="importmap">
{ {
"imports": { "imports": {
@ -67,6 +67,14 @@
</select> </select>
</label> </label>
<label class="cc-color-wrap">
<span class="cc-lbl">color</span>
<select class="cc-color" id="cc-color">
<option value="mono" selected>mono</option>
<option value="original">original</option>
</select>
</label>
<label class="cc-sync-wrap"> <label class="cc-sync-wrap">
<span class="cc-lbl">axes</span> <span class="cc-lbl">axes</span>
<select class="cc-sync" id="cc-sync"> <select class="cc-sync" id="cc-sync">
@ -111,6 +119,6 @@
</section> </section>
<script src="/static/theme.js?v=11"></script> <script src="/static/theme.js?v=11"></script>
<script type="module" src="/static/compare.js?v=7"></script> <script type="module" src="/static/compare.js?v=11"></script>
</body> </body>
</html> </html>

View File

@ -4,7 +4,7 @@
<meta charset="utf-8" /> <meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" /> <meta name="viewport" content="width=device-width,initial-scale=1" />
<title>embedding notebook</title> <title>embedding notebook</title>
<link rel="stylesheet" href="/static/style.css?v=25" /> <link rel="stylesheet" href="/static/style.css?v=26" />
<script src="https://unpkg.com/htmx.org@2.0.4"></script> <script src="https://unpkg.com/htmx.org@2.0.4"></script>
<script type="importmap"> <script type="importmap">
{ {