dr-sandbox/flows/embedding_flow.py
Michael Pilosov e94d28b8fc filenames + run names: J in sci notation (5E-3 not 0.005)
Periods in filenames are avoidable and the Prefect UI dislikes them in
run names. Uses a shared sci_notation helper in main.py mirrored in the
flow. Stem regex (main + parser) now matches J<digits.Ee+-> to accept
both old decimal-J and new sci-J filenames so the two transition
together. J tag in Prefect tag list also uses the sci form, so chip
filters stay consistent.

Backfill script extended to find pre-transition (decimal-J) files on
disk via a second base-stem variant, then rename them to the sci form.
backfill_tags re-patches existing runs so their J tag matches the new
canonical form.

All 13 existing figs + runs renamed / retagged in-place.
2026-04-22 17:54:46 -06:00

467 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# embedding_flow.py
import os
import sys  # NOTE(review): appears unused at module level — the flow re-imports it locally as _sys; confirm before removing.
# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
# Opt out of Prefect's anonymous usage telemetry (same no-op-if-set rule).
os.environ.setdefault("DO_NOT_TRACK", "1")
# Pin per-process thread pools to 1 so Ray's worker parallelism doesn't
# multiply against BLAS/numba/etc. thread pools — otherwise 4 workers × N
# cores → thrash. Must be set before numpy/numba/sklearn import, since
# those libs latch onto these env vars at import time. Ray manages OMP
# per-task-CPU but does NOT manage NUMBA_NUM_THREADS, which is what
# PaCMAP/UMAP use for their optimization loops.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
# Stdlib imports are safe to pull in before the env pinning takes effect —
# none of them read the thread-count vars above.
from datetime import timedelta
import hashlib
import json
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
def _run_args_hash(
ea: Optional[Dict[str, Any]],
gk: Optional[Dict[str, Any]] = None,
) -> str:
"""8-hex digest over (embed_args, generator_kwargs). When gk is empty we
hash embed_args alone — keeps stems stable for plain generators that
never had gen_kwargs (s_curve, plain swiss_roll). Must mirror
app.web.main.run_args_hash exactly."""
if gk:
payload: Any = {"embed_args": ea or {}, "generator_kwargs": gk}
else:
payload = ea or {}
s = json.dumps(payload, sort_keys=True, default=str)
return hashlib.sha1(s.encode()).hexdigest()[:8]
def _sci(v: Any) -> str:
"""Float → compact sci notation without a period (e.g. 0.005 → 5E-3,
0.01 → 1E-2). Keeps Prefect's UI happy — it doesn't like periods in
run names."""
try:
f = float(v)
except (TypeError, ValueError):
return str(v)
m, e = f"{f:.3e}".split("e")
m = m.rstrip("0").rstrip(".")
return f"{m}E{int(e)}"
def _flow_run_name() -> str:
    """Name each Prefect run after the stem of its output fig, so runs are
    searchable / hoverable instead of wearing Prefect's auto-generated
    adjective-animal names."""
    params = runtime.flow_run.parameters or {}

    def _leaf(key: str) -> str:
        # Last dotted component of a module path, "?" when absent.
        return (params.get(key) or "").rsplit(".", 1)[-1] or "?"

    gen_name = _leaf("generator_path")
    emb_name = _leaf("embedder")
    n_points = params.get("num_points", "?")
    n_steps = params.get("num_timesteps", "?")
    jitter = _sci(params.get("jitter_scale", "?"))
    seed_val = params.get("seed", "?")
    args_tag = _run_args_hash(params.get("embed_args"), params.get("generator_kwargs"))
    return f"{gen_name}_{emb_name}_N{n_points}_T{n_steps}_J{jitter}_s{seed_val}_{args_tag}"
from prefect import flow, runtime, task
from prefect.artifacts import create_markdown_artifact, create_table_artifact
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
import pandas as pd
from sklearn.preprocessing import StandardScaler
import embedding_utils as E
from joblib import cpu_count
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Generate the initial data frame using a specified data generator.
    Parameters:
    - generator_path: str
        The full module path to the data generator function (e.g., 'sklearn.datasets.make_s_curve');
        must return a (data, labels) pair where data is a 2-D array.
    - generator_kwargs: Dict[str, Any]
        Keyword arguments to pass to the data generator function.
    - id_column: str
        Column name to use as a unique identifier.
    Returns:
    - df: pd.DataFrame
        DataFrame with one feature_<i> column per data dimension, integer
        IDs in `id_column`, and a constant time=0 column.
    """
    generator_func = E.dynamic_import(generator_path)
    data, _labels = generator_func(**generator_kwargs)
    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    data = StandardScaler().fit_transform(data)
    # Generalized from the previous hard-coded 3 columns: emit one
    # feature_<i> column per dimension so non-3D generators also work.
    # For 3-D data the names and order (feature_0..feature_2) are unchanged.
    columns: Dict[str, Any] = {
        f"feature_{i}": data[:, i] for i in range(data.shape[1])
    }
    columns[id_column] = range(data.shape[0])
    columns["time"] = 0
    df = pd.DataFrame(columns)
    df[id_column] = df[id_column].astype(int)
    return df
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """Produce one jittered copy of *initial_df* per timestep; the actual
    jitter logic lives in embedding_utils."""
    make_snapshots = E.generate_jittered_snapshots
    return make_snapshots(initial_df, num_timesteps, jitter_scale, seed)
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """Embed a single snapshot with the configured reducer; thin wrapper
    over embedding_utils.create_embedding_dataframe."""
    call_kwargs = {
        "snapshot": snapshot,
        "embed_columns": embed_columns,
        "embedding_algorithm_str": embedder,
        "embedding_kwargs": embed_args,
        "id_column": id_column,
        "time_idx": time_idx,
    }
    return E.create_embedding_dataframe(**call_kwargs)
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """Concatenate the per-timestep frames into one plotly-ready DataFrame
    (delegates to embedding_utils)."""
    collect = E.collect_and_prepare_for_plotly
    return collect(embedded_dfs, sort_time=sort_time, id_column=id_column)
def _fmt(v: Any, spec: str = ".4f") -> str:
return format(v, spec) if isinstance(v, (int, float)) else ""
def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
return float(sum(vals) / len(vals)) if vals else None
@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute embedding-stability metrics, write them as a JSON sidecar at
    *output_path*, and publish Prefect table + markdown summary artifacts.

    Parameters:
    - snapshot_list: original (pre-embedding) per-timestep frames.
    - embedded_dfs: embedded per-timestep frames, aligned with snapshot_list.
    - embed_columns: feature columns the metrics compare against.
    - output_path: destination for the sidecar JSON (parents created).
    - meta: run metadata recorded verbatim under the payload's "meta" key.
    - id_column: per-point identifier column.
    - k: neighborhood size for kNN retention.

    Returns *output_path* so downstream tasks can depend on the sidecar.
    """
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)
    # --- Prefect artifacts ---
    # Three per-timestep series, each a list of dicts keyed by "t":
    # frame-to-frame travel, travel vs the initial frame, kNN retention.
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    # Outer-join on t — a series may be missing some timesteps.
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })
    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )
    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )
    # Summary markdown: frame-to-frame stats averaged across timesteps,
    # vs-initial stats at the final timestep only.
    # (Fixed: dropped the unused `ff_last` local — the markdown below uses
    # the _mean_of averages for the frame-to-frame row, not the last entry.)
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )
    return output_path
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """Render the animated embedding figure for *combined_df* and write it
    as standalone HTML to *output_path* (parent dirs created as needed).

    Returns *output_path* so downstream tasks can chain on the file.
    """
    # Durations are coerced to int here because callers may pass floats
    # (e.g. frame_duration / reference_speedup in the flow).
    figure = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=int(frame_duration),
        transition_duration=int(transition_duration),
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
# Flow-level defaults; every one of these can be overridden per-run via
# the corresponding embedding_flow parameter.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# NOTE(review): the 0, 2, 1 ordering looks deliberate (axis swap for the
# reference view?) — confirm before "fixing" it to 0, 1, 2.
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}), flow_run_name=_flow_run_name)
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """Generate a jittered point-cloud time series, embed every timestep in
    parallel on Ray, and write sibling outputs into *output_dir*: a
    reference animation HTML, an embedding animation HTML, a metrics JSON
    sidecar, and (best-effort) a frames.json sidecar.

    Returns (reference_html_path, embedding_html_path, metrics_json_path).
    """
    # Preserve the user-supplied generator_kwargs for hashing / metadata —
    # the merged dict (with random_state defaults + n_samples) goes to the
    # generator itself but those aren't part of the run's semantic identity
    # (random_state=0 is a flow constant; n_samples is captured as `N` in
    # the stem). If the merged dict were hashed, the web app would disagree
    # with the flow because Prefect only records the user-supplied form.
    user_generator_kwargs = dict(generator_kwargs or {})
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **user_generator_kwargs,
        "n_samples": num_points,
    }
    # Copy mutable defaults so per-run mutation can't leak across runs.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Output stems: J uses the sci form (5E-3, not 0.005) and the embed stem
    # carries the 8-hex args hash — must mirror _flow_run_name / the web app.
    _generator = generator_path.split(".")[-1]
    _j = _sci(jitter_scale)
    output_ref: str = (
        f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{_j}_s{seed}.html"
    )
    _args_tag = _run_args_hash(embed_args, user_generator_kwargs)
    output_embed: str = (
        f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{_j}_s{seed}_{_args_tag}.html"
    )
    # [:-5] strips the ".html" suffix before appending the sidecar suffixes.
    output_metrics: str = output_embed[:-5] + ".metrics.json"
    output_frames: str = output_embed[:-5] + ".frames.json"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"
    merged_embed_args = embed_args
    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )
    # Generate snapshots
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()
    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic reasons)
    dates = [
        f"{year}-{month:02d}-01"
        for year in range(2000, 2001 + math.floor(num_timesteps / 12))
        for month in range(1, 13)
    ][:num_timesteps]
    # Apply embeddings in parallel using Prefect's mapping
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[merged_embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )
    # Collect all embeddings
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False
    ).result()
    # make the original snapshots look like the embeddings
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False
    ).result()
    # Keep two feature columns as x/y; overwrite `time` with the embedded
    # frame's time values so the two animations use identical frame labels.
    # (Assumes row order matches between dfr and combined_df — both come
    # from collect_data_task with sort_time=False over aligned inputs.)
    dfr = dfr[embed_columns[:2] + [id_column, "time"]]
    dfr.columns = ["x", "y", id_column, "time"]
    dfr["time"] = combined_df["time"].to_numpy()
    # Plot reference animation — sped up by reference_speedup, floored so
    # the animation never becomes unreadably fast.
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Sidecar stability metrics (travel + kNN retention). Runs in parallel
    # with plotting; writes a JSON next to the embedding fig.
    metrics_path = compute_metrics_task.submit(
        snapshot_list=snapshot_list,
        embedded_dfs=embeddings.result(),
        embed_columns=embed_columns,
        output_path=output_metrics,
        meta={
            "num_points": num_points,
            "num_timesteps": num_timesteps,
            "jitter_scale": jitter_scale,
            "seed": seed,
            "generator_path": generator_path,
            "generator_kwargs": user_generator_kwargs,
            "embedder": embedder,
            "embed_args": merged_embed_args,
        },
        id_column=id_column,
        k=10,
    )
    # Block until the embedding HTML and metrics JSON exist — the frames
    # sidecar below parses the finished HTML file.
    emb_path_result = emb_path.result()
    metrics_path_result = metrics_path.result()
    # Emit a frames.json sidecar so the compare page doesn't have to parse
    # the 5 MB plotly HTML on every first request. Non-critical — the server
    # falls back to HTML parsing when the sidecar is absent.
    try:
        import sys as _sys
        # Make the repo root importable so app.web.plotly_parse resolves
        # when the flow runs from inside dr-sandbox/flows/.
        _root = str(Path(__file__).resolve().parent.parent)
        if _root not in _sys.path:
            _sys.path.insert(0, _root)
        from app.web.plotly_parse import parse_plotly_run
        frames = parse_plotly_run(emb_path_result)
        # Persist generator_kwargs so the server's label enrichment can
        # regenerate the correct dataset variant (swiss_roll vs hole).
        frames.setdefault("meta", {})["generator_kwargs"] = user_generator_kwargs
        Path(output_frames).write_text(
            json.dumps(frames, separators=(",", ":")), encoding="utf-8"
        )
    except Exception as _sidecar_err:
        # Deliberately broad: the sidecar is an optimization, never worth
        # failing the run over. Log the traceback so regressions are visible.
        import traceback as _tb
        print(f"[frames-sidecar] skipped: {_sidecar_err}")
        _tb.print_exc()
    return (ref_path.result(), emb_path_result, metrics_path_result)
if __name__ == "__main__":
    # Serve the flow as a long-lived Prefect deployment. limit=1 presumably
    # caps concurrent flow runs at one (a single Ray pool on this host) —
    # confirm against the Prefect serve() docs for the pinned version.
    embedding_flow.serve(limit=1)
    # embedding_flow()