def _fmt(v: Any, spec: str = ".4f") -> str:
    """Format a numeric value with *spec*; em-dash placeholder for non-numbers (e.g. None)."""
    return format(v, spec) if isinstance(v, (int, float)) else "—"


def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
    """Mean of the numeric values stored under *key* across *series*; None if there are none."""
    vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
    return float(sum(vals) / len(vals)) if vals else None


@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute embedding-stability metrics, write them to a JSON sidecar, and
    publish Prefect table + markdown artifacts summarising them.

    Args:
        snapshot_list: per-timestep input-space frames (carry embed_columns + id).
        embedded_dfs: per-timestep 2-D embeddings (carry id + x + y).
        embed_columns: input-space feature columns used for kNN retention.
        output_path: destination path for the metrics JSON sidecar.
        meta: run metadata recorded alongside the metrics.
        id_column: join key aligning points across frames.
        k: neighbourhood size for kNN retention.

    Returns:
        output_path, so downstream tasks can chain on the sidecar file.
    """
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        # meta["embed_args"] is Dict[str, Any] and may carry values json
        # cannot serialise (numpy scalars, callables); degrade to repr()
        # rather than failing — and retrying — the whole task.
        json.dump(payload, f, indent=2, default=repr)

    # --- Prefect artifacts ---
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]

    # One table row per timestep, outer-joined on t: travel series start at
    # t=1 while knn_retention starts at t=0, so gaps are filled with None.
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    rows = []
    for t in sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t)):
        f_ = ff_by_t.get(t, {})
        v_ = vi_by_t.get(t, {})
        k_ = kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })

    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )

    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )

    # Markdown summary: frame-to-frame stats are averaged over t; vs-initial
    # is reported at the final timestep only (cumulative drift).
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )

    return output_path
".metrics.json" title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise" title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise" @@ -258,7 +358,27 @@ def embedding_flow( samples=samples, ) - return (ref_path.result(), emb_path.result()) + # Sidecar stability metrics (travel + kNN retention). Runs in parallel + # with plotting; writes a JSON next to the embedding fig. + metrics_path = compute_metrics_task.submit( + snapshot_list=snapshot_list, + embedded_dfs=embeddings.result(), + embed_columns=embed_columns, + output_path=output_metrics, + meta={ + "num_points": num_points, + "num_timesteps": num_timesteps, + "jitter_scale": jitter_scale, + "seed": seed, + "generator_path": generator_path, + "embedder": embedder, + "embed_args": merged_embed_args, + }, + id_column=id_column, + k=10, + ) + + return (ref_path.result(), emb_path.result(), metrics_path.result()) if __name__ == "__main__": diff --git a/flows/embedding_utils.py b/flows/embedding_utils.py index b3f9888..4100c34 100644 --- a/flows/embedding_utils.py +++ b/flows/embedding_utils.py @@ -1,7 +1,7 @@ # embedding_utils.py import importlib -from typing import List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import numpy as np import pandas as pd @@ -423,6 +423,115 @@ def generate_initial_frame( return df +def _travel_stats(d: np.ndarray) -> Dict[str, Any]: + """Summary stats for a 1-D vector of displacements.""" + if d.size == 0: + return {"mean": None, "median": None, "p95": None, "max": None, "n_pairs": 0} + return { + "mean": float(np.mean(d)), + "median": float(np.median(d)), + "p95": float(np.percentile(d, 95)), + "max": float(np.max(d)), + "n_pairs": int(d.size), + } + + +def compute_metrics( + snapshot_list: List[pd.DataFrame], + embedded_list: List[pd.DataFrame], + embed_columns: List[str], + id_column: str = "id", + k: int = 10, +) -> Dict[str, Any]: + """ + Per-timestep stability metrics 
for a DR run. + + - travel.frame_to_frame[t]: point displacement in 2-D output space + between timestep t-1 and t (inner-joined on id). + - travel.vs_initial[t]: displacement in 2-D output space between + timestep 0 and t (inner-joined on id). + - knn_retention[t]: mean per-point fraction of k nearest input-space + neighbours that remain neighbours in output space. + + Input: snapshot_list and embedded_list have len == num_timesteps and + align index-wise. Each snapshot carries {embed_columns}+id+time; each + embedding carries id+x+y+time. ids may differ across timesteps (the + generator adds/removes points), so all joins are inner on id. + """ + from sklearn.neighbors import NearestNeighbors + + assert len(snapshot_list) == len(embedded_list), "snapshot/embedding count mismatch" + T = len(embedded_list) + + emb_by_id = [ + df.drop_duplicates(subset=[id_column]).set_index(id_column)[["x", "y"]] + for df in embedded_list + ] + + # --- travel --- + frame_to_frame: List[Dict[str, Any]] = [] + vs_initial: List[Dict[str, Any]] = [] + initial = emb_by_id[0] if T > 0 else None + + for t in range(1, T): + curr = emb_by_id[t] + prev = emb_by_id[t - 1] + + common_ff = curr.index.intersection(prev.index) + d_ff = np.linalg.norm( + curr.loc[common_ff].to_numpy() - prev.loc[common_ff].to_numpy(), axis=1 + ) if len(common_ff) else np.array([]) + frame_to_frame.append({"t": t, **_travel_stats(d_ff)}) + + common_vi = curr.index.intersection(initial.index) + d_vi = np.linalg.norm( + curr.loc[common_vi].to_numpy() - initial.loc[common_vi].to_numpy(), axis=1 + ) if len(common_vi) else np.array([]) + vs_initial.append({"t": t, **_travel_stats(d_vi)}) + + # --- kNN retention --- + knn_retention: List[Dict[str, Any]] = [] + for t in range(T): + snap = snapshot_list[t].drop_duplicates(subset=[id_column]) + emb = embedded_list[t].drop_duplicates(subset=[id_column]) + common = snap[id_column].values + common = np.intersect1d(common, emb[id_column].values) + n = len(common) + k_eff = min(k, 
n - 1) if n > 1 else 0 + if k_eff <= 0: + knn_retention.append({"t": t, "k": k_eff, "mean": None, "n_points": int(n)}) + continue + + snap_idx = snap.set_index(id_column).loc[common] + emb_idx = emb.set_index(id_column).loc[common] + X_in = snap_idx[embed_columns].to_numpy() + X_out = emb_idx[["x", "y"]].to_numpy() + + nn_in = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_in) + nn_out = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_out) + idx_in = nn_in.kneighbors(X_in, return_distance=False)[:, 1:] + idx_out = nn_out.kneighbors(X_out, return_distance=False)[:, 1:] + + # per-row intersection count via broadcast equality + matches = (idx_out[:, :, None] == idx_in[:, None, :]).any(axis=2).sum(axis=1) + retentions = matches / k_eff + knn_retention.append({ + "t": t, + "k": int(k_eff), + "mean": float(np.mean(retentions)), + "n_points": int(n), + }) + + return { + "k": int(k), + "travel": { + "frame_to_frame": frame_to_frame, + "vs_initial": vs_initial, + }, + "knn_retention": knn_retention, + } + + def generate_jittered_snapshots( initial_df: pd.DataFrame, num_timesteps: int,