The output stem now includes an 8-hex sha1 digest of the (key-sorted) embed_args dict, so runs that differ only in embed_args (e.g. UMAP n_neighbors=5 vs 15) produce distinct figs. The stem regex and parser both accept an optional _<hash> tail, so pre-hash figs still render in the runs list and on the compare page; the legacy filename is resolved as an on-disk fallback. The duplicate-submission check now rejects against BOTH the hashed and the legacy hashless variant, so users can't accidentally duplicate an old run either. The flow additionally writes a <stem>.frames.json sidecar next to the plotly HTML (same shape as app/web/plotly_parse returns). The server prefers the sidecar when present and falls back to parsing the HTML for older runs. Sidecar emission is non-critical — any failure is logged and the run continues.
418 lines
14 KiB
Python
418 lines
14 KiB
Python
# embedding_flow.py
import os
import sys

# Environment defaults applied before any heavy imports. setdefault semantics:
# an explicit value already present in the environment always wins.
#
# PREFECT_API_URL — default to the local Docker Prefect server.
# Thread-pool pins — keep per-process BLAS/numba/etc. pools at 1 so Ray's
# worker parallelism doesn't multiply against them (4 workers × N cores →
# thrash). These must be set before numpy/numba/sklearn import, since those
# libs latch onto the env vars at import time. Ray manages OMP per-task-CPU
# but does NOT manage NUMBA_NUM_THREADS, which is what PaCMAP/UMAP use for
# their optimization loops.
_ENV_DEFAULTS = {
    "PREFECT_API_URL": "http://localhost:4200/api",
    "DO_NOT_TRACK": "1",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "NUMBA_NUM_THREADS": "1",
}
for _env_key, _env_value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_value)
|
||
|
||
from datetime import timedelta
|
||
import hashlib
|
||
import json
|
||
import math
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
|
||
def _embed_args_hash(ea: Optional[Dict[str, Any]]) -> str:
|
||
"""8-hex digest of embed_args (keys sorted) — output stem includes this
|
||
so runs differing only in embed_args get distinct files."""
|
||
s = json.dumps(ea or {}, sort_keys=True, default=str)
|
||
return hashlib.sha1(s.encode()).hexdigest()[:8]
|
||
|
||
from prefect import flow, task
|
||
from prefect.artifacts import create_markdown_artifact, create_table_artifact
|
||
from prefect.cache_policies import INPUTS, NO_CACHE
|
||
from prefect_ray import RayTaskRunner
|
||
|
||
import pandas as pd
|
||
from sklearn.preprocessing import StandardScaler
|
||
|
||
import embedding_utils as E
|
||
from joblib import cpu_count
|
||
|
||
|
||
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Build the t=0 frame from a dynamically imported data generator.

    Parameters:
    - generator_path: str
        Full module path of the generator function
        (e.g. 'sklearn.datasets.make_s_curve').
    - generator_kwargs: Dict[str, Any]
        Keyword arguments forwarded to the generator function.
    - id_column: str
        Name of the unique-identifier column to add.

    Returns:
    - pd.DataFrame with three standardized feature columns, an integer id
      column, and a constant time column of 0.
    """
    make_data = E.dynamic_import(generator_path)
    points, _labels = make_data(**generator_kwargs)  # labels are unused here

    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    points = StandardScaler().fit_transform(points)

    # Column order matters for downstream selection: features, id, time.
    columns: Dict[str, Any] = {
        f"feature_{axis}": points[:, axis] for axis in range(3)
    }
    columns[id_column] = range(points.shape[0])
    columns["time"] = 0
    frame = pd.DataFrame(columns)
    frame[id_column] = frame[id_column].astype(int)
    return frame
|
||
|
||
|
||
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """Produce num_timesteps jittered copies of the initial frame.

    Thin task wrapper around E.generate_jittered_snapshots so the work is
    cacheable and visible in the Prefect UI.
    """
    snapshots = E.generate_jittered_snapshots(
        initial_df, num_timesteps, jitter_scale, seed
    )
    return snapshots
|
||
|
||
|
||
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """Embed a single snapshot with the configured reducer.

    Thin task wrapper around E.create_embedding_dataframe; time_idx tags the
    resulting frame (and the task-run name) with its timestep.
    """
    call_kwargs = dict(
        snapshot=snapshot,
        embed_columns=embed_columns,
        embedding_algorithm_str=embedder,
        embedding_kwargs=embed_args,
        id_column=id_column,
        time_idx=time_idx,
    )
    return E.create_embedding_dataframe(**call_kwargs)
|
||
|
||
|
||
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """Concatenate per-timestep frames into one plotly-ready DataFrame.

    Thin task wrapper around E.collect_and_prepare_for_plotly.
    """
    combined = E.collect_and_prepare_for_plotly(
        embedded_dfs,
        sort_time=sort_time,
        id_column=id_column,
    )
    return combined
|
||
|
||
|
||
def _fmt(v: Any, spec: str = ".4f") -> str:
|
||
return format(v, spec) if isinstance(v, (int, float)) else "—"
|
||
|
||
|
||
def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
|
||
vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
|
||
return float(sum(vals) / len(vals)) if vals else None
|
||
|
||
|
||
@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute embedding-stability metrics, write them to a JSON sidecar, and
    publish Prefect table/markdown artifacts summarizing the run.

    Parameters:
    - snapshot_list: original (pre-embedding) per-timestep frames.
    - embedded_dfs: embedded per-timestep frames — assumed aligned 1:1 with
      snapshot_list (TODO confirm ordering guarantee in E.compute_metrics).
    - embed_columns: feature columns the embedding consumed.
    - output_path: destination for the metrics JSON ({"meta": ..., **metrics}).
    - meta: run metadata; must contain at least generator_path, embedder,
      num_points, num_timesteps, jitter_scale and seed (read below).
    - id_column: per-point identifier column, forwarded to E.compute_metrics.
    - k: neighborhood size for the kNN-retention metric.

    Returns:
    - output_path, after the JSON has been written.
    """
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    # Sidecar JSON: run metadata first, then the metric series verbatim.
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)

    # --- Prefect artifacts ---
    # Three per-timestep series; each is a list of dicts keyed by "t".
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]

    # Outer-join the three series on "t" so the table shows every timestep;
    # .get() below renders missing cells as None rather than raising.
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })

    # Short names (last dotted component) keep the artifact description terse.
    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )

    # Full per-timestep table artifact.
    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )

    # Markdown summary: frame-to-frame stats averaged over t, vs-initial at
    # the final timestep, and mean kNN retention across timesteps.
    ff_last = ff[-1] if ff else {}
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )

    return output_path
|
||
|
||
|
||
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """Render the animated embedding figure and persist it as a plotly HTML
    file at output_path (parent directories created as needed).

    Durations are coerced to int because callers may pass floats
    (e.g. frame_duration / reference_speedup in the flow).
    """
    figure = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=int(frame_duration),
        transition_duration=int(transition_duration),
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
|
||
|
||
|
||
# Baseline kwargs for the data generator; embedding_flow merges any
# caller-supplied generator_kwargs over these (caller wins).
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# NOTE(review): feature_2 precedes feature_1 — looks like a deliberate axis
# swap for the default view, but confirm the ordering is intentional.
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
# Default reducer settings: 2-D output with a pinned random_state so repeat
# runs with default args are reproducible.
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
|
||
|
||
|
||
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """End-to-end embedding run: generate data, jitter it over num_timesteps,
    embed every timestep in parallel on Ray, then write a reference animation,
    an embedding animation, a metrics JSON sidecar, and a frames JSON sidecar.

    Returns (reference_html_path, embedding_html_path, metrics_json_path).
    """
    # Caller kwargs win over the defaults; n_samples always tracks num_points.
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    # Defensive copies so we never mutate caller-owned containers.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    _generator = generator_path.split(".")[-1]
    output_ref: str = (
        f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    )
    # embed_args hash in the stem keeps runs that differ only in embed_args
    # from overwriting each other's figs.
    _args_tag = _embed_args_hash(embed_args)
    output_embed: str = (
        f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
    )
    # Sidecars share the embedding stem (strip the ".html" suffix).
    output_metrics: str = output_embed[:-5] + ".metrics.json"
    output_frames: str = output_embed[:-5] + ".frames.json"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"

    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )

    # Generate snapshots
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()

    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic reasons)
    dates = [
        f"{year}-{month:02d}-01"
        for year in range(2000, 2001 + math.floor(num_timesteps / 12))
        for month in range(1, 13)
    ][:num_timesteps]

    # Apply embeddings in parallel using Prefect's mapping
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )

    # Collect all embeddings. BUGFIX: id_column is now forwarded — previously
    # both collect calls silently fell back to the task default "id", breaking
    # any run with a custom id_column.
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False, id_column=id_column
    ).result()

    # make the original snapshots look like the embeddings
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False, id_column=id_column
    ).result()
    dfr = dfr[embed_columns[:2] + [id_column, "time"]]
    dfr.columns = ["x", "y", id_column, "time"]
    # Reuse the embedding frames' time labels so both animations align.
    dfr["time"] = combined_df["time"].to_numpy()

    # Plot reference animation — sped up relative to the embedding animation,
    # with floors (175/350 ms) so frames stay visible at high speedups.
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    # Sidecar stability metrics (travel + kNN retention). Runs in parallel
    # with plotting; writes a JSON next to the embedding fig.
    metrics_path = compute_metrics_task.submit(
        snapshot_list=snapshot_list,
        embedded_dfs=embeddings.result(),
        embed_columns=embed_columns,
        output_path=output_metrics,
        meta={
            "num_points": num_points,
            "num_timesteps": num_timesteps,
            "jitter_scale": jitter_scale,
            "seed": seed,
            "generator_path": generator_path,
            "embedder": embedder,
            "embed_args": embed_args,
        },
        id_column=id_column,
        k=10,
    )

    emb_path_result = emb_path.result()
    metrics_path_result = metrics_path.result()

    # Emit a frames.json sidecar so the compare page doesn't have to parse
    # the 5 MB plotly HTML on every first request. Non-critical — the server
    # falls back to HTML parsing when the sidecar is absent.
    try:
        import sys as _sys
        # Make the repo root importable so app.web.plotly_parse resolves when
        # the flow runs from this file's directory.
        _root = str(Path(__file__).resolve().parent.parent)
        if _root not in _sys.path:
            _sys.path.insert(0, _root)
        from app.web.plotly_parse import parse_plotly_run
        frames = parse_plotly_run(emb_path_result)
        Path(output_frames).write_text(
            json.dumps(frames, separators=(",", ":")), encoding="utf-8"
        )
    except Exception as _sidecar_err:
        # Best-effort by design: log and continue rather than fail the run.
        import traceback as _tb
        print(f"[frames-sidecar] skipped: {_sidecar_err}")
        _tb.print_exc()

    return (ref_path.result(), emb_path_result, metrics_path_result)
|
||
|
||
|
||
if __name__ == "__main__":
    # Serve the flow against the Prefect server configured at module import.
    # limit=1 presumably caps concurrent flow runs on this worker to one —
    # TODO confirm against the Prefect serve() API.
    embedding_flow.serve(limit=1)
    # embedding_flow()
|