dr-sandbox/flows/embedding_flow.py
Michael Pilosov fe49565651 stems: include embed_args hash in output filename + emit frames.json sidecar
Stem grows an 8-hex sha1 digest of the (keys-sorted) embed_args dict, so
runs differing only in embed_args (e.g. UMAP n_neighbors=5 vs 15) now
produce distinct figs. The stem regex and parser both accept an optional
_<hash> tail, so pre-hash figs still render in the runs list and compare
page; the legacy filename is resolved via an on-disk fallback.

Duplicate-submission check now rejects against BOTH the hashed and the
legacy hashless variant so users can't accidentally duplicate an old run
either.

Flow additionally writes a <stem>.frames.json sidecar next to the plotly
HTML (same shape as app/web/plotly_parse returns). Server prefers the
sidecar when present; falls back to parsing HTML for older runs. Sidecar
emission is non-critical — any failure just logs and keeps going.
2026-04-22 15:52:39 -06:00

418 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# embedding_flow.py
import os
import sys
# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
# Opt out of anonymous usage telemetry (DO_NOT_TRACK convention).
os.environ.setdefault("DO_NOT_TRACK", "1")
# Pin per-process thread pools to 1 so Ray's worker parallelism doesn't
# multiply against BLAS/numba/etc. thread pools — otherwise 4 workers × N
# cores → thrash. Must be set before numpy/numba/sklearn import, since
# those libs latch onto these env vars at import time. Ray manages OMP
# per-task-CPU but does NOT manage NUMBA_NUM_THREADS, which is what
# PaCMAP/UMAP use for their optimization loops.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
from datetime import timedelta
import hashlib
import json
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
def _embed_args_hash(ea: Optional[Dict[str, Any]]) -> str:
"""8-hex digest of embed_args (keys sorted) — output stem includes this
so runs differing only in embed_args get distinct files."""
s = json.dumps(ea or {}, sort_keys=True, default=str)
return hashlib.sha1(s.encode()).hexdigest()[:8]
from prefect import flow, task
from prefect.artifacts import create_markdown_artifact, create_table_artifact
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
import pandas as pd
from sklearn.preprocessing import StandardScaler
import embedding_utils as E
from joblib import cpu_count
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Generate the initial data frame using a specified data generator.
    Parameters:
    - generator_path: str
        The full module path to the data generator function (e.g., 'sklearn.datasets.make_s_curve').
    - generator_kwargs: Dict[str, Any]
        Keyword arguments to pass to the data generator function.
    - id_column: str
        Column name to use as a unique identifier.
    Returns:
    - df: pd.DataFrame
        DataFrame with one 'feature_<i>' column per generated feature,
        unique integer IDs, and a constant 'time' column (t=0).
    """
    generator_func = E.dynamic_import(generator_path)
    data, labels = generator_func(**generator_kwargs)
    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    data = StandardScaler().fit_transform(data)
    # Generalized: one column per feature instead of hard-coding exactly
    # three — identical output for 3-feature generators (the default
    # make_s_curve case), no IndexError for 2-D ones, and no silent column
    # drop for higher-dimensional ones.
    df = pd.DataFrame({f"feature_{i}": data[:, i] for i in range(data.shape[1])})
    df[id_column] = range(data.shape[0])
    df[id_column] = df[id_column].astype(int)
    df["time"] = 0
    return df
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """Produce one jittered copy of `initial_df` per timestep.

    Thin wrapper around E.generate_jittered_snapshots; `seed` makes the
    jitter sequence reproducible (and cache-stable for Prefect).
    """
    snapshots = E.generate_jittered_snapshots(
        initial_df, num_timesteps, jitter_scale, seed
    )
    return snapshots
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """Embed one snapshot's `embed_columns` with the algorithm named by the
    dotted path `embedder`, tagging rows with `time_idx`.

    Delegates entirely to embedding_utils; exists as a task so the flow can
    map it across timesteps with per-timestep run names.
    """
    embedded = E.create_embedding_dataframe(
        snapshot=snapshot,
        embed_columns=embed_columns,
        embedding_algorithm_str=embedder,
        embedding_kwargs=embed_args,
        id_column=id_column,
        time_idx=time_idx,
    )
    return embedded
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """Concatenate per-timestep frames into one plotly-ready DataFrame
    (thin wrapper around embedding_utils)."""
    combined = E.collect_and_prepare_for_plotly(
        embedded_dfs, sort_time=sort_time, id_column=id_column
    )
    return combined
def _fmt(v: Any, spec: str = ".4f") -> str:
return format(v, spec) if isinstance(v, (int, float)) else ""
def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
return float(sum(vals) / len(vals)) if vals else None
@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute stability metrics for the run, write them (plus `meta`) as
    JSON to `output_path`, and publish Prefect table/markdown artifacts.

    Returns `output_path`. Metric shapes come from E.compute_metrics:
    travel.frame_to_frame / travel.vs_initial are per-timestep rows with
    mean/median/p95/max/n_pairs; knn_retention rows carry mean/n_points.
    """
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)
    # --- Prefect artifacts ---
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]
    # Re-key each series by timestep so the table can outer-join the three
    # series even when their t-coverage differs (e.g. frame-to-frame has no
    # t=0 row).
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        # .get with {} default leaves cells as None when a series lacks t.
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })
    # Short names (last dotted component) keep artifact descriptions compact.
    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )
    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )
    # Markdown summary: frame-to-frame stats averaged over t, vs-initial
    # taken at the final timestep only.
    ff_last = ff[-1] if ff else {}
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )
    return output_path
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """Render the animated embedding figure and write it as standalone HTML.

    Durations are coerced via int() because the flow may pass floats (the
    reference animation divides by reference_speedup). Returns the path
    that was written.
    """
    animation = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=int(frame_duration),
        transition_duration=int(transition_duration),
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )
    # Ensure the target directory exists before plotly writes the file.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    animation.write_html(output_path)
    return output_path
# Defaults are copied (not referenced) inside embedding_flow(), so a run can
# never mutate this shared module state.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# NOTE(review): feature_2 listed before feature_1 — looks deliberate (affects
# which axes the reference animation plots); confirm before "fixing".
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """End-to-end pipeline: generate data → jittered snapshots → per-timestep
    embeddings → reference + embedding animations, a metrics JSON sidecar,
    and a frames.json sidecar next to the embedding HTML.

    Returns (reference_html_path, embedding_html_path, metrics_json_path).
    """
    # Merge defaults under caller-supplied kwargs; n_samples always tracks
    # num_points regardless of what the caller passed.
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    # Copy list/dict arguments so this run can't mutate caller state or the
    # shared module-level defaults.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    _generator = generator_path.split(".")[-1]
    # Filename stems encode the run parameters; the embedding stem also gets
    # an 8-hex embed_args hash so runs differing only in embed_args write
    # distinct files.
    output_ref: str = (
        f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    )
    _args_tag = _embed_args_hash(embed_args)
    output_embed: str = (
        f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
    )
    # Sidecars share the embedding stem ([:-5] strips the ".html" suffix).
    output_metrics: str = output_embed[:-5] + ".metrics.json"
    output_frames: str = output_embed[:-5] + ".frames.json"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"
    merged_embed_args = embed_args
    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )
    # Generate snapshots
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()
    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic reasons)
    dates = [
        f"{year}-{month:02d}-01"
        for year in range(2000, 2001 + math.floor(num_timesteps / 12))
        for month in range(1, 13)
    ][:num_timesteps]
    # Apply embeddings in parallel using Prefect's mapping
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[merged_embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )
    # Collect all embeddings
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False
    ).result()
    # make the original snapshots look like the embeddings
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False
    ).result()
    # Keep the first two embed columns as x/y so the reference animation has
    # the same plotting schema as the embedding output.
    dfr = dfr[embed_columns[:2] + [id_column, "time"]]
    dfr.columns = ["x", "y", id_column, "time"]
    # NOTE(review): assumes dfr and combined_df are row-aligned (same length
    # and ordering) so the time labels match up — confirm upstream ordering.
    dfr["time"] = combined_df["time"].to_numpy()
    # Plot reference animation
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        # Speed up the reference playback, clamped so it stays watchable.
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )
    # Sidecar stability metrics (travel + kNN retention). Runs in parallel
    # with plotting; writes a JSON next to the embedding fig.
    metrics_path = compute_metrics_task.submit(
        snapshot_list=snapshot_list,
        embedded_dfs=embeddings.result(),
        embed_columns=embed_columns,
        output_path=output_metrics,
        meta={
            "num_points": num_points,
            "num_timesteps": num_timesteps,
            "jitter_scale": jitter_scale,
            "seed": seed,
            "generator_path": generator_path,
            "embedder": embedder,
            "embed_args": merged_embed_args,
        },
        id_column=id_column,
        k=10,
    )
    emb_path_result = emb_path.result()
    metrics_path_result = metrics_path.result()
    # Emit a frames.json sidecar so the compare page doesn't have to parse
    # the 5 MB plotly HTML on every first request. Non-critical — the server
    # falls back to HTML parsing when the sidecar is absent.
    try:
        import sys as _sys
        # Make the repo root importable so app.web.plotly_parse resolves
        # when the flow runs from this subdirectory.
        _root = str(Path(__file__).resolve().parent.parent)
        if _root not in _sys.path:
            _sys.path.insert(0, _root)
        from app.web.plotly_parse import parse_plotly_run
        frames = parse_plotly_run(emb_path_result)
        Path(output_frames).write_text(
            json.dumps(frames, separators=(",", ":")), encoding="utf-8"
        )
    except Exception as _sidecar_err:
        # Best-effort by design: log and continue rather than failing the run.
        import traceback as _tb
        print(f"[frames-sidecar] skipped: {_sidecar_err}")
        _tb.print_exc()
    return (ref_path.result(), emb_path_result, metrics_path_result)
if __name__ == "__main__":
    # Serve as a Prefect deployment, one concurrent run at a time. Swap the
    # commented call below to execute the flow once locally instead.
    embedding_flow.serve(limit=1)
    # embedding_flow()