metrics stored (2x)

Michael Pilosov 2026-04-21 20:41:17 -06:00
parent c6bd693058
commit 3280410405
2 changed files with 231 additions and 2 deletions


@@ -20,11 +20,13 @@ os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
from datetime import timedelta
import json
import math
from pathlib import Path
from typing import Any, Dict, List, Optional

from prefect import flow, task
from prefect.artifacts import create_markdown_artifact, create_table_artifact
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
@@ -113,6 +115,103 @@ def collect_data_task(
)


def _fmt(v: Any, spec: str = ".4f") -> str:
    return format(v, spec) if isinstance(v, (int, float)) else ""


def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
    vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
    return float(sum(vals) / len(vals)) if vals else None


@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)

    # --- Prefect artifacts ---
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })
    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )
    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )
    ff_last = ff[-1] if ff else {}
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )
    return output_path
@task(
    task_run_name="plot-{output_path}",
    retries=3,
@@ -183,6 +282,7 @@ def embedding_flow(
    output_embed: str = (
        f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
    )
    output_metrics: str = output_embed[:-5] + ".metrics.json"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"
@@ -258,7 +358,27 @@ def embedding_flow(
        samples=samples,
    )
    return (ref_path.result(), emb_path.result())
    # Sidecar stability metrics (travel + kNN retention). Runs in parallel
    # with plotting; writes a JSON next to the embedding fig.
    metrics_path = compute_metrics_task.submit(
        snapshot_list=snapshot_list,
        embedded_dfs=embeddings.result(),
        embed_columns=embed_columns,
        output_path=output_metrics,
        meta={
            "num_points": num_points,
            "num_timesteps": num_timesteps,
            "jitter_scale": jitter_scale,
            "seed": seed,
            "generator_path": generator_path,
            "embedder": embedder,
            "embed_args": merged_embed_args,
        },
        id_column=id_column,
        k=10,
    )
    return (ref_path.result(), emb_path.result(), metrics_path.result())


if __name__ == "__main__":
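For context on what the flow now emits: compute_metrics_task writes the metrics payload as a JSON sidecar next to the embedding HTML (output_embed with ".html" swapped for ".metrics.json"). A minimal sketch of reading one back, assuming an illustrative filename (the real name is derived from the flow parameters):

import json
from pathlib import Path

# Hypothetical sidecar path; the flow names it <embedding output>.metrics.json.
path = Path("output/swiss_roll_PCA_N500_T20_J0.05_s0.metrics.json")
payload = json.loads(path.read_text())

print(payload["meta"]["embedder"])                  # dotted embedder path from the run
ff_final = payload["travel"]["frame_to_frame"][-1]  # stats for the last frame pair
print(ff_final["t"], ff_final["mean"], ff_final["p95"], ff_final["n_pairs"])

knn_means = [r["mean"] for r in payload["knn_retention"] if r["mean"] is not None]
if knn_means:
    print(f"mean kNN retention (k={payload['k']}): {sum(knn_means) / len(knn_means):.3f}")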


@@ -1,7 +1,7 @@
# embedding_utils.py
import importlib
from typing import List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import numpy as np
import pandas as pd
@@ -423,6 +423,115 @@ def generate_initial_frame(
    return df


def _travel_stats(d: np.ndarray) -> Dict[str, Any]:
    """Summary stats for a 1-D vector of displacements."""
    if d.size == 0:
        return {"mean": None, "median": None, "p95": None, "max": None, "n_pairs": 0}
    return {
        "mean": float(np.mean(d)),
        "median": float(np.median(d)),
        "p95": float(np.percentile(d, 95)),
        "max": float(np.max(d)),
        "n_pairs": int(d.size),
    }


def compute_metrics(
    snapshot_list: List[pd.DataFrame],
    embedded_list: List[pd.DataFrame],
    embed_columns: List[str],
    id_column: str = "id",
    k: int = 10,
) -> Dict[str, Any]:
    """
    Per-timestep stability metrics for a DR run.

    - travel.frame_to_frame[t]: point displacement in 2-D output space
      between timestep t-1 and t (inner-joined on id).
    - travel.vs_initial[t]: displacement in 2-D output space between
      timestep 0 and t (inner-joined on id).
    - knn_retention[t]: mean per-point fraction of k nearest input-space
      neighbours that remain neighbours in output space.

    Input: snapshot_list and embedded_list have len == num_timesteps and
    align index-wise. Each snapshot carries {embed_columns}+id+time; each
    embedding carries id+x+y+time. ids may differ across timesteps (the
    generator adds/removes points), so all joins are inner on id.
    """
    from sklearn.neighbors import NearestNeighbors

    assert len(snapshot_list) == len(embedded_list), "snapshot/embedding count mismatch"
    T = len(embedded_list)
    emb_by_id = [
        df.drop_duplicates(subset=[id_column]).set_index(id_column)[["x", "y"]]
        for df in embedded_list
    ]

    # --- travel ---
    frame_to_frame: List[Dict[str, Any]] = []
    vs_initial: List[Dict[str, Any]] = []
    initial = emb_by_id[0] if T > 0 else None
    for t in range(1, T):
        curr = emb_by_id[t]
        prev = emb_by_id[t - 1]
        common_ff = curr.index.intersection(prev.index)
        d_ff = np.linalg.norm(
            curr.loc[common_ff].to_numpy() - prev.loc[common_ff].to_numpy(), axis=1
        ) if len(common_ff) else np.array([])
        frame_to_frame.append({"t": t, **_travel_stats(d_ff)})
        common_vi = curr.index.intersection(initial.index)
        d_vi = np.linalg.norm(
            curr.loc[common_vi].to_numpy() - initial.loc[common_vi].to_numpy(), axis=1
        ) if len(common_vi) else np.array([])
        vs_initial.append({"t": t, **_travel_stats(d_vi)})

    # --- kNN retention ---
    knn_retention: List[Dict[str, Any]] = []
    for t in range(T):
        snap = snapshot_list[t].drop_duplicates(subset=[id_column])
        emb = embedded_list[t].drop_duplicates(subset=[id_column])
        common = snap[id_column].values
        common = np.intersect1d(common, emb[id_column].values)
        n = len(common)
        k_eff = min(k, n - 1) if n > 1 else 0
        if k_eff <= 0:
            knn_retention.append({"t": t, "k": k_eff, "mean": None, "n_points": int(n)})
            continue
        snap_idx = snap.set_index(id_column).loc[common]
        emb_idx = emb.set_index(id_column).loc[common]
        X_in = snap_idx[embed_columns].to_numpy()
        X_out = emb_idx[["x", "y"]].to_numpy()
        nn_in = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_in)
        nn_out = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_out)
        idx_in = nn_in.kneighbors(X_in, return_distance=False)[:, 1:]
        idx_out = nn_out.kneighbors(X_out, return_distance=False)[:, 1:]
        # per-row intersection count via broadcast equality
        matches = (idx_out[:, :, None] == idx_in[:, None, :]).any(axis=2).sum(axis=1)
        retentions = matches / k_eff
        knn_retention.append({
            "t": t,
            "k": int(k_eff),
            "mean": float(np.mean(retentions)),
            "n_points": int(n),
        })

    return {
        "k": int(k),
        "travel": {
            "frame_to_frame": frame_to_frame,
            "vs_initial": vs_initial,
        },
        "knn_retention": knn_retention,
    }
def generate_jittered_snapshots(
    initial_df: pd.DataFrame,
    num_timesteps: int,
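The new compute_metrics helper is plain Python (no Prefect), so it can be exercised on its own. Below is a minimal sketch, not part of this commit and using made-up toy data, that feeds it a small jittered point cloud and prints the structure it returns:

import numpy as np
import pandas as pd
import embedding_utils as E  # assumes the repo module is importable

rng = np.random.default_rng(0)
embed_columns = ["f0", "f1", "f2"]
N, T = 50, 3

base = rng.normal(size=(N, len(embed_columns)))
snapshots, embeddings = [], []
for t in range(T):
    X = base + rng.normal(scale=0.05, size=base.shape)  # jittered input frame
    snapshots.append(pd.DataFrame(X, columns=embed_columns).assign(id=np.arange(N), time=t))
    # Stand-in 2-D "embedding": first two input features (a real run uses the DR output).
    embeddings.append(pd.DataFrame({"id": np.arange(N), "x": X[:, 0], "y": X[:, 1], "time": t}))

metrics = E.compute_metrics(
    snapshot_list=snapshots,
    embedded_list=embeddings,
    embed_columns=embed_columns,
    id_column="id",
    k=5,
)
print(metrics["travel"]["frame_to_frame"][0])  # {'t': 1, 'mean': ..., 'n_pairs': 50}
print(metrics["knn_retention"][0])             # {'t': 0, 'k': 5, 'mean': ..., 'n_points': 50}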