dr-sandbox/flows/embedding_flow.py
2026-04-21 20:16:33 -06:00

267 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# embedding_flow.py
import os
import sys
# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
os.environ.setdefault("DO_NOT_TRACK", "1")
# Pin per-process thread pools to 1 so Ray's worker parallelism doesn't
# multiply against BLAS/numba/etc. thread pools -- otherwise 4 workers x N
# cores -> thrash. Must be set before numpy/numba/sklearn import, since
# those libs latch onto these env vars at import time. Ray manages OMP
# per-task-CPU but does NOT manage NUMBA_NUM_THREADS, which is what
# PaCMAP/UMAP use for their optimization loops.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
from datetime import timedelta
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
from prefect import flow, task
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
import pandas as pd
from sklearn.preprocessing import StandardScaler
import embedding_utils as E
from joblib import cpu_count
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Generate the initial data frame using a specified data generator.

    The generated data is z-scored per feature and stored as one
    ``feature_<i>`` column per generated feature, so generators that yield
    fewer or more than three features are supported. For a three-feature
    generator the result is ``feature_0``, ``feature_1``, ``feature_2``.

    Parameters:
    - generator_path: str
        The full module path to the data generator function
        (e.g., 'sklearn.datasets.make_s_curve').
    - generator_kwargs: Dict[str, Any]
        Keyword arguments to pass to the data generator function.
    - id_column: str
        Column name to use as a unique identifier.

    Returns:
    - df: pd.DataFrame
        DataFrame with the standardized feature columns, an integer
        ``id_column`` running from 0 to n_rows - 1, and a ``time`` column
        set to 0.
    """
    generator_func = E.dynamic_import(generator_path)
    # The generator is expected to return a (data, labels) pair; the labels
    # are not used here.
    data, _labels = generator_func(**generator_kwargs)
    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    data = StandardScaler().fit_transform(data)
    # One column per generated feature instead of a hard-coded three.
    columns: Dict[str, Any] = {
        f"feature_{i}": data[:, i] for i in range(data.shape[1])
    }
    columns[id_column] = range(data.shape[0])
    columns["time"] = 0
    df = pd.DataFrame(columns)
    df[id_column] = df[id_column].astype(int)
    return df
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """
    Produce the list of per-timestep snapshots for an initial frame.

    Thin task wrapper: the four arguments are forwarded positionally, in
    order, to ``E.generate_jittered_snapshots`` and its result is returned
    unchanged. The task is declared with the ``INPUTS`` cache policy and a
    12-hour cache expiration.

    Parameters:
    - initial_df: pd.DataFrame
        Starting frame handed to the snapshot generator.
    - num_timesteps: int
        Forwarded as the second positional argument.
    - jitter_scale: float
        Forwarded as the third positional argument.
    - seed: int
        Forwarded as the fourth positional argument (default 42).

    Returns:
    - List[pd.DataFrame]
        Whatever ``E.generate_jittered_snapshots`` returns.
    """
    forwarded = (initial_df, num_timesteps, jitter_scale, seed)
    snapshot_frames = E.generate_jittered_snapshots(*forwarded)
    return snapshot_frames
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """
    Embed a single snapshot by delegating to ``E.create_embedding_dataframe``.

    Parameters:
    - snapshot: pd.DataFrame
        The frame to embed.
    - embed_columns: List[str]
        Forwarded as ``embed_columns``.
    - embedder: str
        Forwarded as ``embedding_algorithm_str``.
    - embed_args: Dict[str, Any]
        Forwarded as ``embedding_kwargs``.
    - time_idx: str | int
        Forwarded as ``time_idx``; also interpolated into the task run name
        (``embed-{time_idx}``).
    - id_column: str
        Forwarded as ``id_column``.

    Returns:
    - pd.DataFrame
        The result of ``E.create_embedding_dataframe``.
    """
    # Map this task's parameter names onto the helper's keyword names.
    helper_kwargs = {
        "snapshot": snapshot,
        "embed_columns": embed_columns,
        "embedding_algorithm_str": embedder,
        "embedding_kwargs": embed_args,
        "id_column": id_column,
        "time_idx": time_idx,
    }
    return E.create_embedding_dataframe(**helper_kwargs)
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """
    Combine a list of per-timestep frames into a single frame for plotting.

    Delegates to ``E.collect_and_prepare_for_plotly``, passing
    ``embedded_dfs`` positionally and ``sort_time`` / ``id_column`` as
    keywords, and returns its result unchanged.

    Parameters:
    - embedded_dfs: List[pd.DataFrame]
        Frames to combine.
    - sort_time: bool
        Forwarded as ``sort_time`` (default True).
    - id_column: str
        Forwarded as ``id_column`` (default "id").

    Returns:
    - pd.DataFrame
        The combined frame produced by the helper.
    """
    collect_options = {"sort_time": sort_time, "id_column": id_column}
    combined = E.collect_and_prepare_for_plotly(embedded_dfs, **collect_options)
    return combined
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """
    Render the animation for ``combined_df`` and write it to an HTML file.

    The figure is built by ``E.plot_embedding_over_time``; the parent
    directory of ``output_path`` is created if missing, and the figure is
    saved with its ``write_html`` method.

    Parameters:
    - combined_df: pd.DataFrame
        Data to plot.
    - title: str
        Figure title.
    - output_path: str
        Destination HTML file; also interpolated into the task run name.
    - frame_duration, transition_duration: int
        Coerced with ``int()`` before being handed to the plotting helper,
        so float inputs are accepted.
    - fixed_axes, equal_aspect: bool
        Forwarded to the plotting helper.
    - samples: int
        Forwarded to the plotting helper.

    Returns:
    - The ``output_path`` that was written.
    """
    plot_options = {
        "title": title,
        "frame_duration": int(frame_duration),
        "transition_duration": int(transition_duration),
        "fixed_axes": fixed_axes,
        "equal_aspect": equal_aspect,
        "samples": samples,
    }
    figure = E.plot_embedding_over_time(combined_df, **plot_options)

    # Only create the directory once the figure has been built successfully.
    html_target = Path(output_path)
    html_target.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
# Fallbacks used by embedding_flow. The flow copies these (dict(...)/list(...)
# or a fresh merged dict) rather than using the module-level objects directly.

# Base generator kwargs; caller-supplied generator_kwargs are merged on top.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# Columns fed to the embedder when embed_columns is None. The first two
# entries also become the x/y columns of the reference animation.
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
# Embedder kwargs used when embed_args is None.
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """
    Generate jittered snapshots of a dataset, embed each one, and write two
    HTML animations: a reference view of the raw data and the embedding.

    Parameters:
    - num_points: int
        Number of samples; always passed to the generator as ``n_samples``,
        overriding any value in ``generator_kwargs``.
    - num_timesteps: int
        Number of snapshots / animation frames.
    - jitter_scale: float
        Jitter magnitude forwarded to the snapshot task.
    - seed: int
        Seed forwarded to the snapshot task.
    - generator_path: str
        Dotted path of the data generator function.
    - generator_kwargs: Optional[Dict[str, Any]]
        Extra generator kwargs, merged over the module defaults.
    - embed_columns: Optional[List[str]]
        Columns handed to the embedder; the first two are also used as the
        x/y axes of the reference animation, so at least two are required.
        Defaults to the module-level default list.
    - embedder: str
        Dotted path of the embedding algorithm.
    - embed_args: Optional[Dict[str, Any]]
        Keyword arguments for the embedder; defaults to the module default.
    - output_dir: str
        Directory for the HTML files (relative or absolute); created if
        missing.
    - id_column: str
        Name of the identifier column.
    - frame_duration, transition_duration: int
        Animation timings for the embedding plot.
    - reference_speedup: float
        Divisor applied to both timings for the reference plot, with floors
        of 175 (frame) and 350 (transition).
    - samples: int
        Forwarded to the plotting task.

    Returns:
    - (reference_path, embedding_path): the two HTML paths that were written.

    Raises:
    - ValueError: if fewer than two embed columns are supplied.
    """
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)

    # The reference plot maps the first two embed columns onto x/y. Fail
    # before any expensive work instead of with a late length-mismatch error.
    if len(embed_columns) < 2:
        raise ValueError(
            "embed_columns must contain at least two columns; "
            f"got {embed_columns!r}"
        )

    # Build output paths with pathlib so the files land in the directory that
    # is created here, including when output_dir is absolute or empty.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    generator_name = generator_path.split(".")[-1]
    embedder_name = embedder.split(".")[-1]
    run_tag = f"N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}"
    output_ref: str = str(out_dir / f"{generator_name}_Reference_{run_tag}.html")
    output_embed: str = str(
        out_dir / f"{generator_name}_{embedder_name}_{run_tag}.html"
    )

    title_ref = f"Reference: {generator_name}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {embedder_name} on {generator_name}, N={num_points} with {jitter_scale} noise"

    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )

    # Generate snapshots
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()

    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic reasons)
    dates = [
        f"{2000 + step // 12}-{step % 12 + 1:02d}-01"
        for step in range(num_timesteps)
    ]

    # Apply embeddings in parallel using Prefect's mapping
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )

    # Collect all embeddings
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False
    ).result()

    # make the original snapshots look like the embeddings
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False
    ).result()
    # Copy the column slice so the in-place edits below act on an independent
    # frame and do not trigger pandas' chained-assignment warning.
    dfr = dfr[embed_columns[:2] + [id_column, "time"]].copy()
    dfr.columns = ["x", "y", id_column, "time"]
    # Reuse the embedding frame's time labels so both animations share frame
    # keys; relies on both frames being collected in the same row order.
    dfr["time"] = combined_df["time"].to_numpy()

    # Plot reference animation
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    return (ref_path.result(), emb_path.result())
if __name__ == "__main__":
    # When run as a script, hand the flow to Prefect via its `serve()` method.
    embedding_flow.serve()
    # Alternative kept for reference: call the flow directly for a single
    # run with the default parameters.
    # embedding_flow()