metrics stored (2x)
This commit is contained in:
parent
c6bd693058
commit
3280410405
@ -20,11 +20,13 @@ os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
|
||||
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
|
||||
|
||||
from datetime import timedelta
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from prefect import flow, task
|
||||
from prefect.artifacts import create_markdown_artifact, create_table_artifact
|
||||
from prefect.cache_policies import INPUTS, NO_CACHE
|
||||
from prefect_ray import RayTaskRunner
|
||||
|
||||
@ -113,6 +115,103 @@ def collect_data_task(
|
||||
)
|
||||
|
||||
|
||||
def _fmt(v: Any, spec: str = ".4f") -> str:
|
||||
return format(v, spec) if isinstance(v, (int, float)) else "—"
|
||||
|
||||
|
||||
def _mean_of(series: List[Dict[str, Any]], key: str) -> Optional[float]:
|
||||
vals = [r[key] for r in series if isinstance(r.get(key), (int, float))]
|
||||
return float(sum(vals) / len(vals)) if vals else None
|
||||
|
||||
|
||||
@task(
    task_run_name="metrics-{output_path}",
    retries=1,
    cache_policy=NO_CACHE,
)
def compute_metrics_task(
    snapshot_list: List[pd.DataFrame],
    embedded_dfs: List[pd.DataFrame],
    embed_columns: List[str],
    output_path: str,
    meta: Dict[str, Any],
    id_column: str = "id",
    k: int = 10,
) -> str:
    """Compute embedding-stability metrics, write them to a JSON sidecar,
    and publish Prefect table/markdown artifacts summarising them.

    Args:
        snapshot_list: per-timestep input-space frames, one DataFrame per t.
        embedded_dfs: per-timestep 2-D embeddings aligned with snapshot_list.
        embed_columns: input-space feature columns used for kNN retention.
        output_path: destination path of the metrics JSON sidecar.
        meta: run metadata (num_points, num_timesteps, jitter_scale, seed,
            generator_path, embedder, embed_args) stored in the payload and
            used in the artifact descriptions.
        id_column: column identifying points across timesteps.
        k: neighbourhood size for kNN retention.

    Returns:
        output_path, so downstream tasks can depend on the written file.
    """
    metrics = E.compute_metrics(
        snapshot_list=snapshot_list,
        embedded_list=embedded_dfs,
        embed_columns=embed_columns,
        id_column=id_column,
        k=k,
    )
    payload = {"meta": meta, **metrics}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the sidecar byte-identical across platforms
    # (json defaults to ASCII-escaped output, so this is behaviour-neutral).
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    # --- Prefect artifacts ---
    ff = metrics["travel"]["frame_to_frame"]
    vi = metrics["travel"]["vs_initial"]
    kn = metrics["knn_retention"]

    # The three series may cover different timesteps; index each by t and
    # emit one table row per timestep seen in any of them.
    ff_by_t = {r["t"]: r for r in ff}
    vi_by_t = {r["t"]: r for r in vi}
    kn_by_t = {r["t"]: r for r in kn}
    all_t = sorted(set(ff_by_t) | set(vi_by_t) | set(kn_by_t))
    rows = []
    for t in all_t:
        f_, v_, k_ = ff_by_t.get(t, {}), vi_by_t.get(t, {}), kn_by_t.get(t, {})
        rows.append({
            "t": t,
            "ff_mean": f_.get("mean"), "ff_median": f_.get("median"),
            "ff_p95": f_.get("p95"), "ff_max": f_.get("max"), "ff_n": f_.get("n_pairs"),
            "vi_mean": v_.get("mean"), "vi_median": v_.get("median"),
            "vi_p95": v_.get("p95"), "vi_max": v_.get("max"), "vi_n": v_.get("n_pairs"),
            "knn_mean": k_.get("mean"), "knn_n": k_.get("n_points"),
        })

    gen_short = meta["generator_path"].split(".")[-1]
    emb_short = meta["embedder"].split(".")[-1]
    desc = (
        f"`{emb_short}` on `{gen_short}` — "
        f"N={meta['num_points']} T={meta['num_timesteps']} "
        f"J={meta['jitter_scale']} s={meta['seed']}"
    )

    create_table_artifact(
        key="embedding-metrics",
        table=rows,
        description=desc,
    )

    # Markdown summary: frame-to-frame stats averaged over all timesteps,
    # vs-initial stats at the final timestep only.
    # (Removed unused local `ff_last` — it was computed but never read.)
    vi_last = vi[-1] if vi else {}
    md = (
        f"### {emb_short} on {gen_short}\n\n"
        f"**N** {meta['num_points']} · **T** {meta['num_timesteps']} · "
        f"**J** {meta['jitter_scale']} · **seed** {meta['seed']}\n\n"
        f"| window | mean | median | p95 | max |\n"
        f"|---|---|---|---|---|\n"
        f"| frame-to-frame (avg over t) | {_fmt(_mean_of(ff, 'mean'))} | "
        f"{_fmt(_mean_of(ff, 'median'))} | {_fmt(_mean_of(ff, 'p95'))} | "
        f"{_fmt(_mean_of(ff, 'max'))} |\n"
        f"| vs-initial (final t) | {_fmt(vi_last.get('mean'))} | "
        f"{_fmt(vi_last.get('median'))} | {_fmt(vi_last.get('p95'))} | "
        f"{_fmt(vi_last.get('max'))} |\n\n"
        f"**kNN retention** (k={metrics['k']}): "
        f"mean across timesteps = {_fmt(_mean_of(kn, 'mean'), '.3f')}\n\n"
        f"_Sidecar JSON:_ `{output_path}`\n"
    )
    create_markdown_artifact(
        key="embedding-metrics-summary",
        markdown=md,
        description=desc,
    )

    return output_path
|
||||
|
||||
|
||||
@task(
|
||||
task_run_name="plot-{output_path}",
|
||||
retries=3,
|
||||
@ -183,6 +282,7 @@ def embedding_flow(
|
||||
output_embed: str = (
|
||||
f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
|
||||
)
|
||||
output_metrics: str = output_embed[:-5] + ".metrics.json"
|
||||
title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
|
||||
title_embed = f"Embedding: {embedder.split('.')[-1]} on {_generator}, N={num_points} with {jitter_scale} noise"
|
||||
|
||||
@ -258,7 +358,27 @@ def embedding_flow(
|
||||
samples=samples,
|
||||
)
|
||||
|
||||
return (ref_path.result(), emb_path.result())
|
||||
# Sidecar stability metrics (travel + kNN retention). Runs in parallel
|
||||
# with plotting; writes a JSON next to the embedding fig.
|
||||
metrics_path = compute_metrics_task.submit(
|
||||
snapshot_list=snapshot_list,
|
||||
embedded_dfs=embeddings.result(),
|
||||
embed_columns=embed_columns,
|
||||
output_path=output_metrics,
|
||||
meta={
|
||||
"num_points": num_points,
|
||||
"num_timesteps": num_timesteps,
|
||||
"jitter_scale": jitter_scale,
|
||||
"seed": seed,
|
||||
"generator_path": generator_path,
|
||||
"embedder": embedder,
|
||||
"embed_args": merged_embed_args,
|
||||
},
|
||||
id_column=id_column,
|
||||
k=10,
|
||||
)
|
||||
|
||||
return (ref_path.result(), emb_path.result(), metrics_path.result())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# embedding_utils.py
|
||||
|
||||
import importlib
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -423,6 +423,115 @@ def generate_initial_frame(
|
||||
return df
|
||||
|
||||
|
||||
def _travel_stats(d: np.ndarray) -> Dict[str, Any]:
|
||||
"""Summary stats for a 1-D vector of displacements."""
|
||||
if d.size == 0:
|
||||
return {"mean": None, "median": None, "p95": None, "max": None, "n_pairs": 0}
|
||||
return {
|
||||
"mean": float(np.mean(d)),
|
||||
"median": float(np.median(d)),
|
||||
"p95": float(np.percentile(d, 95)),
|
||||
"max": float(np.max(d)),
|
||||
"n_pairs": int(d.size),
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(
    snapshot_list: List[pd.DataFrame],
    embedded_list: List[pd.DataFrame],
    embed_columns: List[str],
    id_column: str = "id",
    k: int = 10,
) -> Dict[str, Any]:
    """
    Per-timestep stability metrics for a DR run.

    - travel.frame_to_frame[t]: point displacement in 2-D output space
      between timestep t-1 and t (inner-joined on id).
    - travel.vs_initial[t]: displacement in 2-D output space between
      timestep 0 and t (inner-joined on id).
    - knn_retention[t]: mean per-point fraction of k nearest input-space
      neighbours that remain neighbours in output space.

    Input: snapshot_list and embedded_list have len == num_timesteps and
    align index-wise. Each snapshot carries {embed_columns}+id+time; each
    embedding carries id+x+y+time. ids may differ across timesteps (the
    generator adds/removes points), so all joins are inner on id.

    Raises:
        ValueError: if snapshot_list and embedded_list differ in length.
    """
    from sklearn.neighbors import NearestNeighbors

    # `assert` is stripped under `python -O`; validate inputs explicitly.
    if len(snapshot_list) != len(embedded_list):
        raise ValueError("snapshot/embedding count mismatch")
    T = len(embedded_list)

    # One (x, y) frame per timestep, keyed by point id for inner joins.
    emb_by_id = [
        df.drop_duplicates(subset=[id_column]).set_index(id_column)[["x", "y"]]
        for df in embedded_list
    ]

    # --- travel ---
    frame_to_frame: List[Dict[str, Any]] = []
    vs_initial: List[Dict[str, Any]] = []
    initial = emb_by_id[0] if T > 0 else None

    for t in range(1, T):
        curr = emb_by_id[t]
        prev = emb_by_id[t - 1]

        # Only points present in both frames contribute displacements.
        common_ff = curr.index.intersection(prev.index)
        d_ff = np.linalg.norm(
            curr.loc[common_ff].to_numpy() - prev.loc[common_ff].to_numpy(), axis=1
        ) if len(common_ff) else np.array([])
        frame_to_frame.append({"t": t, **_travel_stats(d_ff)})

        common_vi = curr.index.intersection(initial.index)
        d_vi = np.linalg.norm(
            curr.loc[common_vi].to_numpy() - initial.loc[common_vi].to_numpy(), axis=1
        ) if len(common_vi) else np.array([])
        vs_initial.append({"t": t, **_travel_stats(d_vi)})

    # --- kNN retention ---
    knn_retention: List[Dict[str, Any]] = []
    for t in range(T):
        snap = snapshot_list[t].drop_duplicates(subset=[id_column])
        emb = embedded_list[t].drop_duplicates(subset=[id_column])
        common = snap[id_column].values
        common = np.intersect1d(common, emb[id_column].values)
        n = len(common)
        # Need at least two shared points for a neighbourhood to exist.
        k_eff = min(k, n - 1) if n > 1 else 0
        if k_eff <= 0:
            knn_retention.append({"t": t, "k": k_eff, "mean": None, "n_points": int(n)})
            continue

        # np.intersect1d returns sorted ids, so both .loc selections share
        # the same row order — the positional neighbour indices below refer
        # to the same points in input and output space.
        snap_idx = snap.set_index(id_column).loc[common]
        emb_idx = emb.set_index(id_column).loc[common]
        X_in = snap_idx[embed_columns].to_numpy()
        X_out = emb_idx[["x", "y"]].to_numpy()

        # k_eff + 1 neighbours because each point's nearest hit is itself;
        # [:, 1:] drops that self column. (Exact duplicates could make the
        # self hit land in a later column — acceptable approximation here.)
        nn_in = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_in)
        nn_out = NearestNeighbors(n_neighbors=k_eff + 1).fit(X_out)
        idx_in = nn_in.kneighbors(X_in, return_distance=False)[:, 1:]
        idx_out = nn_out.kneighbors(X_out, return_distance=False)[:, 1:]

        # per-row intersection count via broadcast equality: (n, k, 1) vs
        # (n, 1, k) -> (n, k, k) membership cube, reduced to a count per row.
        matches = (idx_out[:, :, None] == idx_in[:, None, :]).any(axis=2).sum(axis=1)
        retentions = matches / k_eff
        knn_retention.append({
            "t": t,
            "k": int(k_eff),
            "mean": float(np.mean(retentions)),
            "n_points": int(n),
        })

    return {
        "k": int(k),
        "travel": {
            "frame_to_frame": frame_to_frame,
            "vs_initial": vs_initial,
        },
        "knn_retention": knn_retention,
    }
|
||||
|
||||
|
||||
def generate_jittered_snapshots(
|
||||
initial_df: pd.DataFrame,
|
||||
num_timesteps: int,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user