diff --git a/app/web/main.py b/app/web/main.py index 006143b..6cae799 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -13,6 +13,7 @@ from __future__ import annotations import importlib.util import json import os +from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -21,6 +22,13 @@ from fastapi import FastAPI, Form, Request from fastapi.responses import HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from sklearn.datasets import ( + make_blobs, + make_classification, + make_gaussian_quantiles, + make_s_curve, + make_swiss_roll, +) # --------------------------------------------------------------------------- @@ -35,11 +43,137 @@ FIGS_DIR.mkdir(parents=True, exist_ok=True) PREFECT_API = os.environ.get("PREFECT_API_URL", "http://localhost:4200/api") DEPLOYMENT_NAME = "embedding-flow/embedding-flow" -GENERATOR_OPTIONS = [ - ("sklearn.datasets.make_s_curve", "make_s_curve"), - ("sklearn.datasets.make_swiss_roll", "make_swiss_roll"), - ("sklearn.datasets.make_blobs", "make_blobs"), -] + +# --------------------------------------------------------------------------- +# Dataset catalogue +# --------------------------------------------------------------------------- +# Metadata for the /data.json endpoint consumed by the dataset picker, and +# for server-side lookup when the picker posts its selection back. kwargs +# must carry n_features=3 for generators that aren't already 3-D, since +# they'll be forwarded verbatim to the Prefect flow's generator_kwargs. + +DATASET_PREVIEW_N = 5000 +DATASET_PREVIEW_SEED = 0 + +DATASET_META: Dict[str, Dict[str, Any]] = { + "s_curve": { + "name": "S-Curve", + "path": "sklearn.datasets.make_s_curve", + "kwargs": {}, + "description": ( + "A 2-D manifold warped into R³. Continuous label encodes position " + "along the curve — a good test of whether a reducer unrolls the " + "sheet without tearing." + ), + "kind": "continuous", + }, + "swiss_roll": { + "name": "Swiss Roll", + "path": "sklearn.datasets.make_swiss_roll", + "kwargs": {}, + "description": ( + "A rolled-up plane. The canonical hard case for linear methods: " + "PCA collapses the spiral, non-linear methods should recover the " + "unroll." + ), + "kind": "continuous", + }, + "swiss_roll_hole": { + "name": "Swiss Roll (hole)", + "path": "sklearn.datasets.make_swiss_roll", + "kwargs": {"hole": True}, + "description": ( + "Swiss roll with a rectangular hole punched through. Same manifold, " + "non-trivial topology — a faithful unroll should preserve the hole " + "rather than smearing it closed." + ), + "kind": "continuous", + }, + "blobs": { + "name": "Gaussian Blobs", + "path": "sklearn.datasets.make_blobs", + "kwargs": {"n_features": 3, "centers": 5, "cluster_std": 1.0}, + "description": ( + "Five isotropic Gaussian clusters in R³. Discrete class labels. " + "Tests whether a reducer preserves cluster separation when " + "projected to 2-D." + ), + "kind": "categorical", + }, + "gaussian_quantiles": { + "name": "Gaussian Quantiles", + "path": "sklearn.datasets.make_gaussian_quantiles", + "kwargs": {"n_features": 3, "n_classes": 4}, + "description": ( + "Concentric Gaussian shells in R³; class = which shell. Classes " + "are linearly inseparable by construction — PCA collapses them, " + "kernel and manifold methods have a chance." + ), + "kind": "categorical", + }, + "classification": { + "name": "Hypercube Clusters", + "path": "sklearn.datasets.make_classification", + "kwargs": { + "n_features": 3, + "n_informative": 3, + "n_redundant": 0, + "n_repeated": 0, + "n_classes": 4, + "n_clusters_per_class": 2, + "class_sep": 1.5, + }, + "description": ( + "Four classes, two sub-clusters each, placed at hypercube vertices " + "with informative noise. A denser discrete test than blobs — " + "within-class bimodality stresses cluster-preserving reducers." + ), + "kind": "categorical", + }, +} + + +@lru_cache(maxsize=1) +def _dataset_previews() -> Dict[str, Dict[str, Any]]: + """Attach freshly-generated points+labels to the catalogue for the picker.""" + N, SEED = DATASET_PREVIEW_N, DATASET_PREVIEW_SEED + s, sl = make_s_curve(n_samples=N, noise=0.03, random_state=SEED) + sr, srl = make_swiss_roll(n_samples=N, noise=0.15, random_state=SEED) + srh, srhl = make_swiss_roll(n_samples=N, noise=0.15, hole=True, random_state=SEED) + b, bl = make_blobs( + n_samples=N, n_features=3, centers=5, cluster_std=1.0, random_state=SEED + ) + gq, gql = make_gaussian_quantiles( + n_samples=N, n_features=3, n_classes=4, random_state=SEED + ) + cls, clsl = make_classification( + n_samples=N, + n_features=3, + n_informative=3, + n_redundant=0, + n_repeated=0, + n_classes=4, + n_clusters_per_class=2, + class_sep=1.5, + random_state=SEED, + ) + samples = { + "s_curve": (s, sl), + "swiss_roll": (sr, srl), + "swiss_roll_hole": (srh, srhl), + "blobs": (b, bl), + "gaussian_quantiles": (gq, gql), + "classification": (cls, clsl), + } + out: Dict[str, Dict[str, Any]] = {} + for key, meta in DATASET_META.items(): + pts, labels = samples[key] + out[key] = { + **meta, + "points": pts.tolist(), + "labels": labels.tolist(), + } + return out # --------------------------------------------------------------------------- @@ -394,7 +528,6 @@ async def index(request: Request) -> HTMLResponse: "reducers": reducers, "default_reducer": default_reducer, "default_spec": default_spec, - "generators": GENERATOR_OPTIONS, "runs": views, "deployment_id": dep_id, "prefect_api": PREFECT_API, @@ -402,6 +535,11 @@ async def index(request: Request) -> HTMLResponse: ) +@app.get("/data.json") +async def data_json() -> JSONResponse: + return JSONResponse(_dataset_previews()) + + @app.get("/reducer-form", response_class=HTMLResponse) async def reducer_form(request: Request, name: str) -> HTMLResponse: spec = REDUCERS.get(name) @@ -436,7 +574,24 @@ async def submit(request: Request) -> HTMLResponse: status_code=400, ) - # Data params + # Dataset came from the picker via dataset_id; fall back to explicit + # generator_path / generator_kwargs if a client posts those directly. + dataset_id = data.get("dataset_id") or "" + if dataset_id and dataset_id in DATASET_META: + meta = DATASET_META[dataset_id] + generator_path = meta["path"] + generator_kwargs = dict(meta["kwargs"]) + else: + generator_path = data.get("generator_path") or "sklearn.datasets.make_s_curve" + raw_kwargs = data.get("generator_kwargs") or "" + try: + generator_kwargs = json.loads(raw_kwargs) if raw_kwargs else {} + except json.JSONDecodeError as e: + return HTMLResponse( + f"