compare: color points by their original-dataset label (mono|original toggle)
Server enrichment regenerates the dataset deterministically (random_state=0, matching the flow's _DEFAULT_GENERATOR_KWARGS — the stem's seed drives jitter, not generation) and attaches per-point labels + label_kind to frames.json. Client picks the dataset-picker's scheme: continuous ramp for s_curve/swiss_roll, 8-color categorical palette for blobs/gaussian_quantiles/classification. Jitter- added points (id >= num_points) render black. Rainbow material is opaque with alpha cutoff so overlapping points don't blend to the ramp midpoint. Swiss_roll and swiss_roll_hole collide on generator_path; the plain variant wins for now (kwargs aren't preserved through the flow's metrics.json). Bumped Cache-Control on the frames endpoint so browsers don't cache stale pre-enrichment payloads.
This commit is contained in:
parent
d3f5088233
commit
9277229024
@ -811,12 +811,48 @@ _STEM_RE = re.compile(
|
|||||||
r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
|
r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Map short generator name ("make_blobs") to its DATASET_META entry.
|
||||||
|
# swiss_roll and swiss_roll_hole collide on path; first wins (plain variant).
|
||||||
|
_GEN_TO_META: Dict[str, Dict[str, Any]] = {}
|
||||||
|
for _m in DATASET_META.values():
|
||||||
|
_GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m)
|
||||||
|
|
||||||
|
|
||||||
|
def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Attach per-point class/continuous labels by regenerating the dataset
|
||||||
|
with the same (generator, n_samples, kwargs). The stem's `seed` drives
|
||||||
|
jitter — NOT generator — so we always use random_state=0 to match the
|
||||||
|
flow's _DEFAULT_GENERATOR_KWARGS. Jitter-added points (id >= num_points)
|
||||||
|
get None so the client renders them as black."""
|
||||||
|
meta = _GEN_TO_META.get(d["meta"].get("generator") or "")
|
||||||
|
if not meta:
|
||||||
|
return d
|
||||||
|
try:
|
||||||
|
mod_path, cls_name = meta["path"].rsplit(".", 1)
|
||||||
|
fn = getattr(importlib.import_module(mod_path), cls_name)
|
||||||
|
N = int(d["meta"]["num_points"])
|
||||||
|
_, gen_labels = fn(n_samples=N, random_state=0, **meta["kwargs"])
|
||||||
|
out_labels: List[Optional[float]] = []
|
||||||
|
for pid in d["point_ids"]:
|
||||||
|
if isinstance(pid, int) and 0 <= pid < N:
|
||||||
|
v = gen_labels[pid]
|
||||||
|
out_labels.append(float(v) if hasattr(v, "item") or isinstance(v, (int, float)) else None)
|
||||||
|
else:
|
||||||
|
out_labels.append(None)
|
||||||
|
d["labels"] = out_labels
|
||||||
|
d["label_kind"] = meta["kind"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=32)
|
@lru_cache(maxsize=32)
|
||||||
def _cached_frames(stem: str) -> str:
|
def _cached_frames(stem: str) -> str:
|
||||||
"""Parse <stem>.html and return the frames dict as a JSON string."""
|
"""Parse <stem>.html and return the frames dict as a JSON string."""
|
||||||
html = FIGS_DIR / f"{stem}.html"
|
html = FIGS_DIR / f"{stem}.html"
|
||||||
return json.dumps(parse_plotly_run(html), separators=(",", ":"))
|
d = parse_plotly_run(html)
|
||||||
|
d = _enrich_with_labels(d)
|
||||||
|
return json.dumps(d, separators=(",", ":"))
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/runs/{stem}/frames.json")
|
@app.get("/api/runs/{stem}/frames.json")
|
||||||
@ -830,7 +866,11 @@ async def run_frames(stem: str) -> Response:
|
|||||||
payload = _cached_frames(stem)
|
payload = _cached_frames(stem)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, f"parse failed: {e}")
|
raise HTTPException(500, f"parse failed: {e}")
|
||||||
return Response(content=payload, media_type="application/json")
|
return Response(
|
||||||
|
content=payload,
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Cache-Control": "no-cache"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/compare", response_class=HTMLResponse)
|
@app.get("/compare", response_class=HTMLResponse)
|
||||||
|
|||||||
@ -22,6 +22,7 @@ const scrub = document.getElementById('cc-scrub');
|
|||||||
const speedSel = document.getElementById('cc-speed');
|
const speedSel = document.getElementById('cc-speed');
|
||||||
const syncSel = document.getElementById('cc-sync');
|
const syncSel = document.getElementById('cc-sync');
|
||||||
const motionSel = document.getElementById('cc-motion');
|
const motionSel = document.getElementById('cc-motion');
|
||||||
|
const colorSel = document.getElementById('cc-color');
|
||||||
const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]');
|
const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]');
|
||||||
const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]');
|
const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]');
|
||||||
|
|
||||||
@ -48,6 +49,84 @@ function makeDiskTexture(size = 64) {
|
|||||||
|
|
||||||
const DISK_TEX = makeDiskTexture();
|
const DISK_TEX = makeDiskTexture();
|
||||||
|
|
||||||
|
// Continuous ramp (blue → orange), matching dataset-picker.js exactly.
|
||||||
|
function rampContinuous(t, out) {
|
||||||
|
const hue = (1 - t) * 215 + t * 28;
|
||||||
|
const sat = 0.62;
|
||||||
|
const lit = 0.50 + (t - 0.5) * 0.08;
|
||||||
|
return (out || new THREE.Color()).setHSL(hue / 360, sat, lit);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8-color categorical palette — same hex list as dataset-picker.js.
|
||||||
|
const CATEGORICAL_HEX = [
|
||||||
|
'#1f4e5f', '#c97b3f', '#8b5a9f', '#5a8560',
|
||||||
|
'#c74a5e', '#6b7d8f', '#b89f51', '#4a6fa5',
|
||||||
|
];
|
||||||
|
const CATEGORICAL = CATEGORICAL_HEX.map(h => new THREE.Color(h));
|
||||||
|
|
||||||
|
// Precompute per-point RGB, indexed by position in data.point_ids.
|
||||||
|
// If the server attached labels + label_kind, color by that (ramp for
|
||||||
|
// continuous, palette for categorical) to match the dataset picker. Points
|
||||||
|
// with null labels (jitter-added, id >= num_points) stay (0,0,0) = black.
|
||||||
|
// Falls back to a frame-0-present ordinal ramp when no labels are present.
|
||||||
|
function buildIdColorsRGB(data) {
|
||||||
|
const n = data.point_ids.length;
|
||||||
|
const rgb = new Float32Array(n * 3);
|
||||||
|
const labels = data.labels || [];
|
||||||
|
const kind = data.label_kind || null;
|
||||||
|
const hasRealLabels = kind && labels.some(v => v != null && v !== '');
|
||||||
|
|
||||||
|
if (hasRealLabels) {
|
||||||
|
const tmp = new THREE.Color();
|
||||||
|
if (kind === 'categorical') {
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
const v = labels[i];
|
||||||
|
if (v == null) continue;
|
||||||
|
const c = CATEGORICAL[((v | 0) % CATEGORICAL.length + CATEGORICAL.length) % CATEGORICAL.length];
|
||||||
|
rgb[i * 3 + 0] = c.r;
|
||||||
|
rgb[i * 3 + 1] = c.g;
|
||||||
|
rgb[i * 3 + 2] = c.b;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let lo = Infinity, hi = -Infinity;
|
||||||
|
for (const v of labels) { if (v == null) continue; if (v < lo) lo = v; if (v > hi) hi = v; }
|
||||||
|
const range = (hi - lo) || 1;
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
const v = labels[i];
|
||||||
|
if (v == null) continue;
|
||||||
|
rampContinuous((v - lo) / range, tmp);
|
||||||
|
rgb[i * 3 + 0] = tmp.r;
|
||||||
|
rgb[i * 3 + 1] = tmp.g;
|
||||||
|
rgb[i * 3 + 2] = tmp.b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rgb;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: rainbow-by-ordinal over frame-0-present points.
|
||||||
|
const frame0 = data.frames[0];
|
||||||
|
if (!frame0) return rgb;
|
||||||
|
const originalPositions = [];
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
if (frame0.x[i] != null && frame0.y[i] != null
|
||||||
|
&& !Number.isNaN(frame0.x[i]) && !Number.isNaN(frame0.y[i])) {
|
||||||
|
originalPositions.push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
originalPositions.sort((a, b) => data.point_ids[a] - data.point_ids[b]);
|
||||||
|
const nOrig = originalPositions.length;
|
||||||
|
const tmp = new THREE.Color();
|
||||||
|
for (let k = 0; k < nOrig; k++) {
|
||||||
|
const t = nOrig > 1 ? k / (nOrig - 1) : 0.5;
|
||||||
|
rampContinuous(t, tmp);
|
||||||
|
const idx = originalPositions[k];
|
||||||
|
rgb[idx * 3 + 0] = tmp.r;
|
||||||
|
rgb[idx * 3 + 1] = tmp.g;
|
||||||
|
rgb[idx * 3 + 2] = tmp.b;
|
||||||
|
}
|
||||||
|
return rgb;
|
||||||
|
}
|
||||||
|
|
||||||
function readVar(name, fallback) {
|
function readVar(name, fallback) {
|
||||||
const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||||
return v || fallback;
|
return v || fallback;
|
||||||
@ -87,13 +166,16 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
// non-null points into the prefix and call setDrawRange(0, count).
|
// non-null points into the prefix and call setDrawRange(0, count).
|
||||||
const maxN = data.point_ids.length;
|
const maxN = data.point_ids.length;
|
||||||
const positions = new Float32Array(maxN * 3);
|
const positions = new Float32Array(maxN * 3);
|
||||||
|
const colors = new Float32Array(maxN * 3);
|
||||||
const ids = new Int32Array(maxN); // packed point_id per drawn vertex
|
const ids = new Int32Array(maxN); // packed point_id per drawn vertex
|
||||||
|
const idColorRGB = buildIdColorsRGB(data);
|
||||||
|
|
||||||
const geo = new THREE.BufferGeometry();
|
const geo = new THREE.BufferGeometry();
|
||||||
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
|
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
|
||||||
|
geo.setAttribute('color', new THREE.BufferAttribute(colors, 3));
|
||||||
geo.setDrawRange(0, 0);
|
geo.setDrawRange(0, 0);
|
||||||
|
|
||||||
const mat = new THREE.PointsMaterial({
|
const matMono = new THREE.PointsMaterial({
|
||||||
size: 6.0,
|
size: 6.0,
|
||||||
sizeAttenuation: false,
|
sizeAttenuation: false,
|
||||||
map: DISK_TEX,
|
map: DISK_TEX,
|
||||||
@ -102,6 +184,19 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
depthWrite: false,
|
depthWrite: false,
|
||||||
color: new THREE.Color(panelAccent(slotId)),
|
color: new THREE.Color(panelAccent(slotId)),
|
||||||
});
|
});
|
||||||
|
const matRainbow = new THREE.PointsMaterial({
|
||||||
|
size: 6.0,
|
||||||
|
sizeAttenuation: false,
|
||||||
|
map: DISK_TEX,
|
||||||
|
// Opaque + alpha cutoff rather than alpha blending: with a ramp, many
|
||||||
|
// overlapping translucent points average out to the middle of the ramp
|
||||||
|
// (blue + orange → green), washing out the whole cloud. Hard edges keep
|
||||||
|
// each point's true color visible.
|
||||||
|
transparent: false,
|
||||||
|
alphaTest: 0.5,
|
||||||
|
vertexColors: true,
|
||||||
|
});
|
||||||
|
let mat = matMono;
|
||||||
const points = new THREE.Points(geo, mat);
|
const points = new THREE.Points(geo, mat);
|
||||||
scene.add(points);
|
scene.add(points);
|
||||||
|
|
||||||
@ -185,11 +280,26 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function applyColorsFromTheme() {
|
function applyColorsFromTheme() {
|
||||||
mat.color.set(panelAccent(slotId));
|
matMono.color.set(panelAccent(slotId));
|
||||||
hiMat.color.set(highlightColor());
|
hiMat.color.set(highlightColor());
|
||||||
scene.background = new THREE.Color(readPanelBg(panelEl));
|
scene.background = new THREE.Color(readPanelBg(panelEl));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function setColorMode(mode) {
|
||||||
|
const next = mode === 'mono' ? matMono : matRainbow;
|
||||||
|
if (points.material !== next) {
|
||||||
|
points.material = next;
|
||||||
|
mat = next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the precomputed per-id RGB into the packed position `j`.
|
||||||
|
function writePackedColor(j, i) {
|
||||||
|
colors[j * 3 + 0] = idColorRGB[i * 3 + 0];
|
||||||
|
colors[j * 3 + 1] = idColorRGB[i * 3 + 1];
|
||||||
|
colors[j * 3 + 2] = idColorRGB[i * 3 + 2];
|
||||||
|
}
|
||||||
|
|
||||||
// Pack frame `f` into the geometry buffer, skipping null x/y.
|
// Pack frame `f` into the geometry buffer, skipping null x/y.
|
||||||
function setFrame(f) {
|
function setFrame(f) {
|
||||||
const frame = data.frames[f];
|
const frame = data.frames[f];
|
||||||
@ -204,6 +314,7 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
positions[j * 3 + 0] = x;
|
positions[j * 3 + 0] = x;
|
||||||
positions[j * 3 + 1] = y;
|
positions[j * 3 + 1] = y;
|
||||||
positions[j * 3 + 2] = 0;
|
positions[j * 3 + 2] = 0;
|
||||||
|
writePackedColor(j, i);
|
||||||
packedX[j] = x;
|
packedX[j] = x;
|
||||||
packedY[j] = y;
|
packedY[j] = y;
|
||||||
packedId[j] = ptIds[i];
|
packedId[j] = ptIds[i];
|
||||||
@ -211,6 +322,7 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
}
|
}
|
||||||
packedN = j;
|
packedN = j;
|
||||||
geo.attributes.position.needsUpdate = true;
|
geo.attributes.position.needsUpdate = true;
|
||||||
|
geo.attributes.color.needsUpdate = true;
|
||||||
geo.setDrawRange(0, packedN);
|
geo.setDrawRange(0, packedN);
|
||||||
applyHighlightForCurrentFrame();
|
applyHighlightForCurrentFrame();
|
||||||
}
|
}
|
||||||
@ -241,6 +353,7 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
positions[j * 3 + 0] = xi;
|
positions[j * 3 + 0] = xi;
|
||||||
positions[j * 3 + 1] = yi;
|
positions[j * 3 + 1] = yi;
|
||||||
positions[j * 3 + 2] = 0;
|
positions[j * 3 + 2] = 0;
|
||||||
|
writePackedColor(j, i);
|
||||||
packedX[j] = xi;
|
packedX[j] = xi;
|
||||||
packedY[j] = yi;
|
packedY[j] = yi;
|
||||||
packedId[j] = ptIds[i];
|
packedId[j] = ptIds[i];
|
||||||
@ -248,6 +361,7 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
}
|
}
|
||||||
packedN = j;
|
packedN = j;
|
||||||
geo.attributes.position.needsUpdate = true;
|
geo.attributes.position.needsUpdate = true;
|
||||||
|
geo.attributes.color.needsUpdate = true;
|
||||||
geo.setDrawRange(0, packedN);
|
geo.setDrawRange(0, packedN);
|
||||||
applyHighlightForCurrentFrame();
|
applyHighlightForCurrentFrame();
|
||||||
}
|
}
|
||||||
@ -321,7 +435,8 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
function dispose() {
|
function dispose() {
|
||||||
geo.dispose();
|
geo.dispose();
|
||||||
hiGeo.dispose();
|
hiGeo.dispose();
|
||||||
mat.dispose();
|
matMono.dispose();
|
||||||
|
matRainbow.dispose();
|
||||||
hiMat.dispose();
|
hiMat.dispose();
|
||||||
renderer.dispose();
|
renderer.dispose();
|
||||||
if (renderer.domElement.parentNode) {
|
if (renderer.domElement.parentNode) {
|
||||||
@ -339,6 +454,7 @@ function createPanel({ slotId, panelEl, data }) {
|
|||||||
data,
|
data,
|
||||||
setFrame,
|
setFrame,
|
||||||
setFrameInterpolated,
|
setFrameInterpolated,
|
||||||
|
setColorMode,
|
||||||
setBounds,
|
setBounds,
|
||||||
setHighlight,
|
setHighlight,
|
||||||
resize,
|
resize,
|
||||||
@ -410,7 +526,7 @@ function markParamDiffs(metaA, metaB) {
|
|||||||
// -------- main ------------------------------------------------------------
|
// -------- main ------------------------------------------------------------
|
||||||
|
|
||||||
async function fetchFrames(stem) {
|
async function fetchFrames(stem) {
|
||||||
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`);
|
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`, { cache: 'no-store' });
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
throw new Error(`${res.status} ${res.statusText}`);
|
throw new Error(`${res.status} ${res.statusText}`);
|
||||||
}
|
}
|
||||||
@ -571,6 +687,13 @@ async function main() {
|
|||||||
applyU(parseFloat(scrub.value) / SCRUB_MAX);
|
applyU(parseFloat(scrub.value) / SCRUB_MAX);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
function applyColorMode() {
|
||||||
|
const mode = colorSel.value;
|
||||||
|
for (const p of Object.values(panels)) p?.setColorMode(mode);
|
||||||
|
}
|
||||||
|
colorSel.addEventListener('change', applyColorMode);
|
||||||
|
applyColorMode();
|
||||||
|
|
||||||
// ---- linked hover -----------------------------------------------------
|
// ---- linked hover -----------------------------------------------------
|
||||||
function wireHover(pA, pB) {
|
function wireHover(pA, pB) {
|
||||||
if (!pA) return;
|
if (!pA) return;
|
||||||
|
|||||||
@ -1635,7 +1635,8 @@ button.submit:disabled { background: var(--faint); border-color: var(--faint); c
|
|||||||
|
|
||||||
.compare-controls .cc-speed-wrap,
|
.compare-controls .cc-speed-wrap,
|
||||||
.compare-controls .cc-sync-wrap,
|
.compare-controls .cc-sync-wrap,
|
||||||
.compare-controls .cc-motion-wrap {
|
.compare-controls .cc-motion-wrap,
|
||||||
|
.compare-controls .cc-color-wrap {
|
||||||
display: inline-flex;
|
display: inline-flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
gap: 0.35rem;
|
gap: 0.35rem;
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
||||||
<title>embedding notebook · compare</title>
|
<title>embedding notebook · compare</title>
|
||||||
<link rel="stylesheet" href="/static/style.css?v=25" />
|
<link rel="stylesheet" href="/static/style.css?v=26" />
|
||||||
<script type="importmap">
|
<script type="importmap">
|
||||||
{
|
{
|
||||||
"imports": {
|
"imports": {
|
||||||
@ -67,6 +67,14 @@
|
|||||||
</select>
|
</select>
|
||||||
</label>
|
</label>
|
||||||
|
|
||||||
|
<label class="cc-color-wrap">
|
||||||
|
<span class="cc-lbl">color</span>
|
||||||
|
<select class="cc-color" id="cc-color">
|
||||||
|
<option value="mono" selected>mono</option>
|
||||||
|
<option value="original">original</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
|
||||||
<label class="cc-sync-wrap">
|
<label class="cc-sync-wrap">
|
||||||
<span class="cc-lbl">axes</span>
|
<span class="cc-lbl">axes</span>
|
||||||
<select class="cc-sync" id="cc-sync">
|
<select class="cc-sync" id="cc-sync">
|
||||||
@ -111,6 +119,6 @@
|
|||||||
</section>
|
</section>
|
||||||
|
|
||||||
<script src="/static/theme.js?v=11"></script>
|
<script src="/static/theme.js?v=11"></script>
|
||||||
<script type="module" src="/static/compare.js?v=7"></script>
|
<script type="module" src="/static/compare.js?v=11"></script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
||||||
<title>embedding notebook</title>
|
<title>embedding notebook</title>
|
||||||
<link rel="stylesheet" href="/static/style.css?v=25" />
|
<link rel="stylesheet" href="/static/style.css?v=26" />
|
||||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||||
<script type="importmap">
|
<script type="importmap">
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user