compare: color points by their original-dataset label (mono|original toggle)
Server enrichment regenerates the dataset deterministically (random_state=0, matching the flow's _DEFAULT_GENERATOR_KWARGS — the stem's seed drives jitter, not generation) and attaches per-point labels + label_kind to frames.json. Client picks the dataset-picker's scheme: continuous ramp for s_curve/swiss_roll, 8-color categorical palette for blobs/gaussian_quantiles/classification. Jitter- added points (id >= num_points) render black. Rainbow material is opaque with alpha cutoff so overlapping points don't blend to the ramp midpoint. Swiss_roll and swiss_roll_hole collide on generator_path; the plain variant wins for now (kwargs aren't preserved through the flow's metrics.json). Bumped Cache-Control on the frames endpoint so browsers don't cache stale pre-enrichment payloads.
This commit is contained in:
parent
d3f5088233
commit
9277229024
@ -811,12 +811,48 @@ _STEM_RE = re.compile(
|
||||
r"^make_[A-Za-z_]+?_[A-Za-z]+_N\d+_T\d+_J[\d.]+_s\d+$"
|
||||
)
|
||||
|
||||
# Map short generator name ("make_blobs") to its DATASET_META entry.
|
||||
# swiss_roll and swiss_roll_hole collide on path; first wins (plain variant).
|
||||
_GEN_TO_META: Dict[str, Dict[str, Any]] = {}
|
||||
for _m in DATASET_META.values():
|
||||
_GEN_TO_META.setdefault(_m["path"].rsplit(".", 1)[-1], _m)
|
||||
|
||||
|
||||
def _enrich_with_labels(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Attach per-point class/continuous labels by regenerating the dataset
|
||||
with the same (generator, n_samples, kwargs). The stem's `seed` drives
|
||||
jitter — NOT generator — so we always use random_state=0 to match the
|
||||
flow's _DEFAULT_GENERATOR_KWARGS. Jitter-added points (id >= num_points)
|
||||
get None so the client renders them as black."""
|
||||
meta = _GEN_TO_META.get(d["meta"].get("generator") or "")
|
||||
if not meta:
|
||||
return d
|
||||
try:
|
||||
mod_path, cls_name = meta["path"].rsplit(".", 1)
|
||||
fn = getattr(importlib.import_module(mod_path), cls_name)
|
||||
N = int(d["meta"]["num_points"])
|
||||
_, gen_labels = fn(n_samples=N, random_state=0, **meta["kwargs"])
|
||||
out_labels: List[Optional[float]] = []
|
||||
for pid in d["point_ids"]:
|
||||
if isinstance(pid, int) and 0 <= pid < N:
|
||||
v = gen_labels[pid]
|
||||
out_labels.append(float(v) if hasattr(v, "item") or isinstance(v, (int, float)) else None)
|
||||
else:
|
||||
out_labels.append(None)
|
||||
d["labels"] = out_labels
|
||||
d["label_kind"] = meta["kind"]
|
||||
except Exception:
|
||||
pass
|
||||
return d
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _cached_frames(stem: str) -> str:
|
||||
"""Parse <stem>.html and return the frames dict as a JSON string."""
|
||||
html = FIGS_DIR / f"{stem}.html"
|
||||
return json.dumps(parse_plotly_run(html), separators=(",", ":"))
|
||||
d = parse_plotly_run(html)
|
||||
d = _enrich_with_labels(d)
|
||||
return json.dumps(d, separators=(",", ":"))
|
||||
|
||||
|
||||
@app.get("/api/runs/{stem}/frames.json")
|
||||
@ -830,7 +866,11 @@ async def run_frames(stem: str) -> Response:
|
||||
payload = _cached_frames(stem)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"parse failed: {e}")
|
||||
return Response(content=payload, media_type="application/json")
|
||||
return Response(
|
||||
content=payload,
|
||||
media_type="application/json",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/compare", response_class=HTMLResponse)
|
||||
|
||||
@ -22,6 +22,7 @@ const scrub = document.getElementById('cc-scrub');
|
||||
const speedSel = document.getElementById('cc-speed');
|
||||
const syncSel = document.getElementById('cc-sync');
|
||||
const motionSel = document.getElementById('cc-motion');
|
||||
const colorSel = document.getElementById('cc-color');
|
||||
const timeAEl = document.getElementById('cc-time').querySelector('[data-role="time-a"]');
|
||||
const timeBEl = document.getElementById('cc-time').querySelector('[data-role="time-b"]');
|
||||
|
||||
@ -48,6 +49,84 @@ function makeDiskTexture(size = 64) {
|
||||
|
||||
const DISK_TEX = makeDiskTexture();
|
||||
|
||||
// Continuous ramp (blue → orange), matching dataset-picker.js exactly.
|
||||
function rampContinuous(t, out) {
|
||||
const hue = (1 - t) * 215 + t * 28;
|
||||
const sat = 0.62;
|
||||
const lit = 0.50 + (t - 0.5) * 0.08;
|
||||
return (out || new THREE.Color()).setHSL(hue / 360, sat, lit);
|
||||
}
|
||||
|
||||
// 8-color categorical palette — same hex list as dataset-picker.js.
|
||||
const CATEGORICAL_HEX = [
|
||||
'#1f4e5f', '#c97b3f', '#8b5a9f', '#5a8560',
|
||||
'#c74a5e', '#6b7d8f', '#b89f51', '#4a6fa5',
|
||||
];
|
||||
const CATEGORICAL = CATEGORICAL_HEX.map(h => new THREE.Color(h));
|
||||
|
||||
// Precompute per-point RGB, indexed by position in data.point_ids.
|
||||
// If the server attached labels + label_kind, color by that (ramp for
|
||||
// continuous, palette for categorical) to match the dataset picker. Points
|
||||
// with null labels (jitter-added, id >= num_points) stay (0,0,0) = black.
|
||||
// Falls back to a frame-0-present ordinal ramp when no labels are present.
|
||||
function buildIdColorsRGB(data) {
|
||||
const n = data.point_ids.length;
|
||||
const rgb = new Float32Array(n * 3);
|
||||
const labels = data.labels || [];
|
||||
const kind = data.label_kind || null;
|
||||
const hasRealLabels = kind && labels.some(v => v != null && v !== '');
|
||||
|
||||
if (hasRealLabels) {
|
||||
const tmp = new THREE.Color();
|
||||
if (kind === 'categorical') {
|
||||
for (let i = 0; i < n; i++) {
|
||||
const v = labels[i];
|
||||
if (v == null) continue;
|
||||
const c = CATEGORICAL[((v | 0) % CATEGORICAL.length + CATEGORICAL.length) % CATEGORICAL.length];
|
||||
rgb[i * 3 + 0] = c.r;
|
||||
rgb[i * 3 + 1] = c.g;
|
||||
rgb[i * 3 + 2] = c.b;
|
||||
}
|
||||
} else {
|
||||
let lo = Infinity, hi = -Infinity;
|
||||
for (const v of labels) { if (v == null) continue; if (v < lo) lo = v; if (v > hi) hi = v; }
|
||||
const range = (hi - lo) || 1;
|
||||
for (let i = 0; i < n; i++) {
|
||||
const v = labels[i];
|
||||
if (v == null) continue;
|
||||
rampContinuous((v - lo) / range, tmp);
|
||||
rgb[i * 3 + 0] = tmp.r;
|
||||
rgb[i * 3 + 1] = tmp.g;
|
||||
rgb[i * 3 + 2] = tmp.b;
|
||||
}
|
||||
}
|
||||
return rgb;
|
||||
}
|
||||
|
||||
// Fallback: rainbow-by-ordinal over frame-0-present points.
|
||||
const frame0 = data.frames[0];
|
||||
if (!frame0) return rgb;
|
||||
const originalPositions = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
if (frame0.x[i] != null && frame0.y[i] != null
|
||||
&& !Number.isNaN(frame0.x[i]) && !Number.isNaN(frame0.y[i])) {
|
||||
originalPositions.push(i);
|
||||
}
|
||||
}
|
||||
originalPositions.sort((a, b) => data.point_ids[a] - data.point_ids[b]);
|
||||
const nOrig = originalPositions.length;
|
||||
const tmp = new THREE.Color();
|
||||
for (let k = 0; k < nOrig; k++) {
|
||||
const t = nOrig > 1 ? k / (nOrig - 1) : 0.5;
|
||||
rampContinuous(t, tmp);
|
||||
const idx = originalPositions[k];
|
||||
rgb[idx * 3 + 0] = tmp.r;
|
||||
rgb[idx * 3 + 1] = tmp.g;
|
||||
rgb[idx * 3 + 2] = tmp.b;
|
||||
}
|
||||
return rgb;
|
||||
}
|
||||
|
||||
function readVar(name, fallback) {
|
||||
const v = getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
return v || fallback;
|
||||
@ -87,13 +166,16 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
// non-null points into the prefix and call setDrawRange(0, count).
|
||||
const maxN = data.point_ids.length;
|
||||
const positions = new Float32Array(maxN * 3);
|
||||
const colors = new Float32Array(maxN * 3);
|
||||
const ids = new Int32Array(maxN); // packed point_id per drawn vertex
|
||||
const idColorRGB = buildIdColorsRGB(data);
|
||||
|
||||
const geo = new THREE.BufferGeometry();
|
||||
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
|
||||
geo.setAttribute('color', new THREE.BufferAttribute(colors, 3));
|
||||
geo.setDrawRange(0, 0);
|
||||
|
||||
const mat = new THREE.PointsMaterial({
|
||||
const matMono = new THREE.PointsMaterial({
|
||||
size: 6.0,
|
||||
sizeAttenuation: false,
|
||||
map: DISK_TEX,
|
||||
@ -102,6 +184,19 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
depthWrite: false,
|
||||
color: new THREE.Color(panelAccent(slotId)),
|
||||
});
|
||||
const matRainbow = new THREE.PointsMaterial({
|
||||
size: 6.0,
|
||||
sizeAttenuation: false,
|
||||
map: DISK_TEX,
|
||||
// Opaque + alpha cutoff rather than alpha blending: with a ramp, many
|
||||
// overlapping translucent points average out to the middle of the ramp
|
||||
// (blue + orange → green), washing out the whole cloud. Hard edges keep
|
||||
// each point's true color visible.
|
||||
transparent: false,
|
||||
alphaTest: 0.5,
|
||||
vertexColors: true,
|
||||
});
|
||||
let mat = matMono;
|
||||
const points = new THREE.Points(geo, mat);
|
||||
scene.add(points);
|
||||
|
||||
@ -185,11 +280,26 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
}
|
||||
|
||||
function applyColorsFromTheme() {
|
||||
mat.color.set(panelAccent(slotId));
|
||||
matMono.color.set(panelAccent(slotId));
|
||||
hiMat.color.set(highlightColor());
|
||||
scene.background = new THREE.Color(readPanelBg(panelEl));
|
||||
}
|
||||
|
||||
function setColorMode(mode) {
|
||||
const next = mode === 'mono' ? matMono : matRainbow;
|
||||
if (points.material !== next) {
|
||||
points.material = next;
|
||||
mat = next;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the precomputed per-id RGB into the packed position `j`.
|
||||
function writePackedColor(j, i) {
|
||||
colors[j * 3 + 0] = idColorRGB[i * 3 + 0];
|
||||
colors[j * 3 + 1] = idColorRGB[i * 3 + 1];
|
||||
colors[j * 3 + 2] = idColorRGB[i * 3 + 2];
|
||||
}
|
||||
|
||||
// Pack frame `f` into the geometry buffer, skipping null x/y.
|
||||
function setFrame(f) {
|
||||
const frame = data.frames[f];
|
||||
@ -204,6 +314,7 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
positions[j * 3 + 0] = x;
|
||||
positions[j * 3 + 1] = y;
|
||||
positions[j * 3 + 2] = 0;
|
||||
writePackedColor(j, i);
|
||||
packedX[j] = x;
|
||||
packedY[j] = y;
|
||||
packedId[j] = ptIds[i];
|
||||
@ -211,6 +322,7 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
}
|
||||
packedN = j;
|
||||
geo.attributes.position.needsUpdate = true;
|
||||
geo.attributes.color.needsUpdate = true;
|
||||
geo.setDrawRange(0, packedN);
|
||||
applyHighlightForCurrentFrame();
|
||||
}
|
||||
@ -241,6 +353,7 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
positions[j * 3 + 0] = xi;
|
||||
positions[j * 3 + 1] = yi;
|
||||
positions[j * 3 + 2] = 0;
|
||||
writePackedColor(j, i);
|
||||
packedX[j] = xi;
|
||||
packedY[j] = yi;
|
||||
packedId[j] = ptIds[i];
|
||||
@ -248,6 +361,7 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
}
|
||||
packedN = j;
|
||||
geo.attributes.position.needsUpdate = true;
|
||||
geo.attributes.color.needsUpdate = true;
|
||||
geo.setDrawRange(0, packedN);
|
||||
applyHighlightForCurrentFrame();
|
||||
}
|
||||
@ -321,7 +435,8 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
function dispose() {
|
||||
geo.dispose();
|
||||
hiGeo.dispose();
|
||||
mat.dispose();
|
||||
matMono.dispose();
|
||||
matRainbow.dispose();
|
||||
hiMat.dispose();
|
||||
renderer.dispose();
|
||||
if (renderer.domElement.parentNode) {
|
||||
@ -339,6 +454,7 @@ function createPanel({ slotId, panelEl, data }) {
|
||||
data,
|
||||
setFrame,
|
||||
setFrameInterpolated,
|
||||
setColorMode,
|
||||
setBounds,
|
||||
setHighlight,
|
||||
resize,
|
||||
@ -410,7 +526,7 @@ function markParamDiffs(metaA, metaB) {
|
||||
// -------- main ------------------------------------------------------------
|
||||
|
||||
async function fetchFrames(stem) {
|
||||
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`);
|
||||
const res = await fetch(`/api/runs/${encodeURIComponent(stem)}/frames.json`, { cache: 'no-store' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`${res.status} ${res.statusText}`);
|
||||
}
|
||||
@ -571,6 +687,13 @@ async function main() {
|
||||
applyU(parseFloat(scrub.value) / SCRUB_MAX);
|
||||
});
|
||||
|
||||
function applyColorMode() {
|
||||
const mode = colorSel.value;
|
||||
for (const p of Object.values(panels)) p?.setColorMode(mode);
|
||||
}
|
||||
colorSel.addEventListener('change', applyColorMode);
|
||||
applyColorMode();
|
||||
|
||||
// ---- linked hover -----------------------------------------------------
|
||||
function wireHover(pA, pB) {
|
||||
if (!pA) return;
|
||||
|
||||
@ -1635,7 +1635,8 @@ button.submit:disabled { background: var(--faint); border-color: var(--faint); c
|
||||
|
||||
.compare-controls .cc-speed-wrap,
|
||||
.compare-controls .cc-sync-wrap,
|
||||
.compare-controls .cc-motion-wrap {
|
||||
.compare-controls .cc-motion-wrap,
|
||||
.compare-controls .cc-color-wrap {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
||||
<title>embedding notebook · compare</title>
|
||||
<link rel="stylesheet" href="/static/style.css?v=25" />
|
||||
<link rel="stylesheet" href="/static/style.css?v=26" />
|
||||
<script type="importmap">
|
||||
{
|
||||
"imports": {
|
||||
@ -67,6 +67,14 @@
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label class="cc-color-wrap">
|
||||
<span class="cc-lbl">color</span>
|
||||
<select class="cc-color" id="cc-color">
|
||||
<option value="mono" selected>mono</option>
|
||||
<option value="original">original</option>
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label class="cc-sync-wrap">
|
||||
<span class="cc-lbl">axes</span>
|
||||
<select class="cc-sync" id="cc-sync">
|
||||
@ -111,6 +119,6 @@
|
||||
</section>
|
||||
|
||||
<script src="/static/theme.js?v=11"></script>
|
||||
<script type="module" src="/static/compare.js?v=7"></script>
|
||||
<script type="module" src="/static/compare.js?v=11"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1" />
|
||||
<title>embedding notebook</title>
|
||||
<link rel="stylesheet" href="/static/style.css?v=25" />
|
||||
<link rel="stylesheet" href="/static/style.css?v=26" />
|
||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||
<script type="importmap">
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user