dr-sandbox/flows/embedding_flow.py
Michael Pilosov b744c48348 stems: fold generator_kwargs into the hash; fix swiss_roll vs hole ambiguity
- run_args_hash now covers (embed_args, generator_kwargs). When gen_kwargs
  is empty we still hash embed_args alone — so plain generators (s_curve,
  plain swiss_roll) keep their stems and no existing plain-gen figs need
  renaming. Kwargs-bearing variants (swiss_roll_hole, blobs,
  gaussian_quantiles, classification) now disambiguate properly.
- Flow persists generator_kwargs into metrics.json meta AND into the
  frames.json sidecar meta, so the label-enrichment path can find it
  without another lookup.
- _enrich_with_labels discovers gen_kwargs in priority: payload meta -->
  sibling metrics.json --> DATASET_META first-match. It matches the
  DATASET_META entry by (path, kwargs) so swiss_roll_hole is no longer
  confused for plain swiss_roll.
- _cached_frames overrides meta.stem with the URL-requested stem before
  enrichment — after a backfill rename the sidecar's baked-in stem is
  stale, and we were then failing to find the sibling metrics.json.
- Submit duplicate-check uses the new hash and keeps the hashless-legacy
  check as a safety net.
- backfill_hashes.py rewritten: queries Prefect for each recent run's
  full params, finds the matching fig under any of (current, legacy,
  hashless) names, renames to the current scheme and patches
  generator_kwargs into metrics.json.
2026-04-22 16:30:42 -06:00

446 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# embedding_flow.py
"""Prefect flow that generates a synthetic 3-feature dataset, jitters it over
time, embeds every snapshot with a configurable reducer, and writes animated
Plotly figures plus `.metrics.json` / `.frames.json` sidecars next to them."""
import os
import sys
# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
# Opt out of telemetry.
os.environ.setdefault("DO_NOT_TRACK", "1")
# Pin per-process thread pools to 1 so Ray's worker parallelism doesn't
# multiply against BLAS/numba/etc. thread pools — otherwise 4 workers × N
# cores → thrash. Must be set before numpy/numba/sklearn import, since
# those libs latch onto these env vars at import time. Ray manages OMP
# per-task-CPU but does NOT manage NUMBA_NUM_THREADS, which is what
# PaCMAP/UMAP use for their optimization loops.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
from datetime import timedelta
import hashlib
import json
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
def _run_args_hash(
ea: Optional[Dict[str, Any]],
gk: Optional[Dict[str, Any]] = None,
) -> str:
"""8-hex digest over (embed_args, generator_kwargs). When gk is empty we
hash embed_args alone — keeps stems stable for plain generators that
never had gen_kwargs (s_curve, plain swiss_roll). Must mirror
app.web.main.run_args_hash exactly."""
if gk:
payload: Any = {"embed_args": ea or {}, "generator_kwargs": gk}
else:
payload = ea or {}
s = json.dumps(payload, sort_keys=True, default=str)
return hashlib.sha1(s.encode()).hexdigest()[:8]
def _flow_run_name() -> str:
    """Name each Prefect run after the stem of its output fig so runs are
    searchable / hoverable instead of wearing Prefect's auto-generated
    adjective-animal names."""
    params = runtime.flow_run.parameters or {}

    def _short(key: str) -> str:
        # Last dotted component of a module path, or "?" when missing/empty.
        return (params.get(key) or "").rsplit(".", 1)[-1] or "?"

    gen = _short("generator_path")
    emb = _short("embedder")
    n_points = params.get("num_points", "?")
    n_steps = params.get("num_timesteps", "?")
    jitter = params.get("jitter_scale", "?")
    seed = params.get("seed", "?")
    tag = _run_args_hash(params.get("embed_args"), params.get("generator_kwargs"))
    return f"{gen}_{emb}_N{n_points}_T{n_steps}_J{jitter}_s{seed}_{tag}"
from prefect import flow, runtime, task
from prefect.artifacts import create_markdown_artifact, create_table_artifact
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
import pandas as pd
from sklearn.preprocessing import StandardScaler
import embedding_utils as E
from joblib import cpu_count
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """Build the t=0 DataFrame from a dynamically imported data generator.

    Parameters:
    - generator_path: full module path of the generator function
      (e.g. 'sklearn.datasets.make_s_curve').
    - generator_kwargs: keyword arguments forwarded to the generator.
    - id_column: name of the unique-identifier column.

    Returns a DataFrame with columns feature_0..feature_2, the id column,
    and a constant time=0.
    """
    generator = E.dynamic_import(generator_path)
    raw, _labels = generator(**generator_kwargs)
    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    scaled = StandardScaler().fit_transform(raw)
    n_rows = scaled.shape[0]
    frame = pd.DataFrame(
        {
            "feature_0": scaled[:, 0],
            "feature_1": scaled[:, 1],
            "feature_2": scaled[:, 2],
            id_column: range(n_rows),
            "time": 0,
        }
    )
    frame[id_column] = frame[id_column].astype(int)
    return frame
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """Produce one jittered copy of the initial frame per timestep."""
    snapshots = E.generate_jittered_snapshots(
        initial_df, num_timesteps, jitter_scale, seed
    )
    return snapshots
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """Embed a single snapshot via E.create_embedding_dataframe.

    The reducer is named by its dotted path (`embedder`) and configured
    with `embed_args`; `time_idx` names the task run for traceability.
    """
    embedded = E.create_embedding_dataframe(
        snapshot=snapshot,
        embed_columns=embed_columns,
        embedding_algorithm_str=embedder,
        embedding_kwargs=embed_args,
        id_column=id_column,
        time_idx=time_idx,
    )
    return embedded
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """Combine per-timestep frames into one Plotly-ready DataFrame."""
    combined = E.collect_and_prepare_for_plotly(
        embedded_dfs, sort_time=sort_time, id_column=id_column
    )
    return combined
def _fmt(v: Any, spec: str = ".4f") -> str:
return format(v, spec) if isinstance(v, (int, float)) else ""
def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
return float(sum(vals) / len(vals)) if vals else None
@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute embedding-stability metrics, write them plus *meta* to
    *output_path* as JSON, and publish Prefect table/markdown artifacts.

    Parameters:
    - snapshot_list / embedded_dfs: per-timestep originals and embeddings.
    - embed_columns: feature columns the embeddings were computed from.
    - output_path: destination for the metrics JSON sidecar.
    - meta: run parameters persisted into the sidecar's "meta" key.
    - k: neighborhood size for kNN retention.

    Returns the sidecar path it wrote.
    """
    # Delegates the math to E.compute_metrics; the keys used below ("travel"
    # with frame_to_frame / vs_initial series, "knn_retention", "k") are its
    # contract.
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    # Sidecar JSON: run meta alongside the raw metric series.
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)
    # --- Prefect artifacts ---
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]
    # Re-key each series by timestep so rows can be joined on t even when a
    # series lacks entries for some timesteps.
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })
    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )
    # Full per-timestep table artifact.
    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )
    # Markdown summary: frame-to-frame stats averaged over t, vs-initial
    # stats at the final timestep, and mean kNN retention.
    # (Fix: removed the previous unused `ff_last` local — only vi_last is read.)
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )
    return output_path
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """Render the animated embedding figure and write it to *output_path*.

    Returns the path written so downstream code can depend on it.
    """
    # Durations are coerced to int because callers may pass floats
    # (e.g. frame_duration / reference_speedup).
    figure = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=int(frame_duration),
        transition_duration=int(transition_duration),
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
# Fallbacks used by embedding_flow when the caller passes None.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# NOTE(review): columns deliberately ordered feature_0, feature_2, feature_1?
# Confirm the swap of the last two is intentional before "fixing" it.
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}), flow_run_name=_flow_run_name)
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """End-to-end pipeline: generate a dataset, jitter it over num_timesteps,
    embed every snapshot with *embedder*, then write

    - a reference animation of the raw (scaled) data,
    - an embedding animation whose filename stem carries the run-args hash,
    - a `.metrics.json` sidecar (stability metrics + run meta), and
    - a `.frames.json` sidecar (pre-parsed Plotly frames; best-effort).

    Returns a tuple (reference_html, embedding_html, metrics_json) of paths.
    """
    # Merge caller kwargs over defaults; n_samples always tracks num_points
    # so the generated dataset matches the N baked into the fig stem.
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    # Copy mutable inputs/defaults so module-level constants are never mutated.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    _generator = generator_path.split(".")[-1]
    # The reference fig's stem carries no args hash; the embedding stem does.
    output_ref: str = (
        f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    )
    # Hash covers (embed_args, generator_kwargs) to disambiguate kwargs-bearing
    # dataset variants (e.g. swiss_roll_hole) from their plain counterparts.
    _args_tag = _run_args_hash(embed_args, generator_kwargs)
    output_embed: str = (
        f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
    )
    # Sidecars share the embedding stem; [:-5] strips the ".html" suffix.
    output_metrics: str = output_embed[:-5] + ".metrics.json"
    output_frames: str = output_embed[:-5] + ".frames.json"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"
    merged_embed_args = embed_args
    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )
    # Generate snapshots (blocks on the initial frame's result first).
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()
    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic
    # reasons); the year range over-generates and is sliced to length.
    dates = [
        f"{year}-{month:02d}-01"
        for year in range(2000, 2001 + math.floor(num_timesteps / 12))
        for month in range(1, 13)
    ][:num_timesteps]
    # Apply embeddings in parallel using Prefect's mapping; scalar args are
    # broadcast by repeating them num_timesteps times.
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[merged_embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )
    # Collect all embeddings
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False
    ).result()
    # make the original snapshots look like the embeddings: keep the first
    # two embed columns as x/y plus id and time.
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False
    ).result()
    dfr = dfr[embed_columns[:2] + [id_column, "time"]]
    dfr.columns = ["x", "y", id_column, "time"]
    # NOTE(review): assumes combined_df and dfr share identical row ordering
    # and length — confirm collect_data_task preserves input order.
    dfr["time"] = combined_df["time"].to_numpy()
    # Plot reference animation (sped up, with floors on the durations).
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Sidecar stability metrics (travel + kNN retention). Runs in parallel
    # with plotting; writes a JSON next to the embedding fig.
    metrics_path = compute_metrics_task.submit(
        snapshot_list=snapshot_list,
        embedded_dfs=embeddings.result(),
        embed_columns=embed_columns,
        output_path=output_metrics,
        meta={
            "num_points": num_points,
            "num_timesteps": num_timesteps,
            "jitter_scale": jitter_scale,
            "seed": seed,
            "generator_path": generator_path,
            "generator_kwargs": generator_kwargs or {},
            "embedder": embedder,
            "embed_args": merged_embed_args,
        },
        id_column=id_column,
        k=10,
    )
    emb_path_result = emb_path.result()
    metrics_path_result = metrics_path.result()
    # Emit a frames.json sidecar so the compare page doesn't have to parse
    # the 5 MB plotly HTML on every first request. Non-critical — the server
    # falls back to HTML parsing when the sidecar is absent.
    try:
        import sys as _sys
        # Make the repo root importable so app.web.plotly_parse resolves
        # when this file runs from dr-sandbox/flows/.
        _root = str(Path(__file__).resolve().parent.parent)
        if _root not in _sys.path:
            _sys.path.insert(0, _root)
        from app.web.plotly_parse import parse_plotly_run
        frames = parse_plotly_run(emb_path_result)
        # Persist generator_kwargs so the server's label enrichment can
        # regenerate the correct dataset variant (swiss_roll vs hole).
        frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {}
        Path(output_frames).write_text(
            json.dumps(frames, separators=(",", ":")), encoding="utf-8"
        )
    except Exception as _sidecar_err:
        # Best-effort by design: log and continue rather than fail the run.
        import traceback as _tb
        print(f"[frames-sidecar] skipped: {_sidecar_err}")
        _tb.print_exc()
    return (ref_path.result(), emb_path_result, metrics_path_result)
if __name__ == "__main__":
    # Serve the flow as a long-lived local deployment; limit=1 caps
    # concurrent flow runs to one at a time.
    embedding_flow.serve(limit=1)