dr-sandbox/flows/embedding_flow.py
2026-04-21 19:55:01 -06:00

255 lines
7.9 KiB
Python

# embedding_flow.py
import os
import sys
# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
os.environ.setdefault("DO_NOT_TRACK", "1")
from datetime import timedelta
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
from prefect import flow, task
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner
import pandas as pd
from sklearn.preprocessing import StandardScaler
import embedding_utils as E
from joblib import cpu_count
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Generate the initial (time 0) data frame using a specified data generator.

    The task result is cached on its inputs for one hour (see the decorator).

    Parameters:
    - generator_path: str
        The full module path to the data generator function
        (e.g., 'sklearn.datasets.make_s_curve'). It is resolved with
        ``E.dynamic_import``.
    - generator_kwargs: Dict[str, Any]
        Keyword arguments to pass to the data generator function.
    - id_column: str
        Column name to use as a unique identifier.

    Returns:
    - df: pd.DataFrame
        One ``feature_<i>`` column per feature returned by the generator
        (z-scored per feature), an integer ``id_column`` running from 0 to
        n_samples - 1, and a ``time`` column set to 0 for every row.
    """
    generator_func = E.dynamic_import(generator_path)
    # The generator must return a (data, labels) pair; the labels are not used.
    data, _labels = generator_func(**generator_kwargs)

    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    data = StandardScaler().fit_transform(data)

    # Build the feature columns from the actual width of the data instead of
    # assuming exactly three features: 2-feature generators no longer raise
    # IndexError and wider generators are no longer silently truncated.
    n_samples, n_features = data.shape
    columns: Dict[str, Any] = {
        f"feature_{i}": data[:, i] for i in range(n_features)
    }
    columns[id_column] = range(n_samples)
    columns["time"] = 0

    df = pd.DataFrame(columns)
    df[id_column] = df[id_column].astype(int)
    return df
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """
    Build the list of per-timestep snapshots from the initial frame.

    Thin task wrapper: all four arguments are forwarded positionally, in
    order, to ``E.generate_jittered_snapshots`` and its result is returned
    unchanged. The task result is cached on its inputs for twelve hours.
    """
    jitter_inputs = (initial_df, num_timesteps, jitter_scale, seed)
    return E.generate_jittered_snapshots(*jitter_inputs)
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """
    Embed a single snapshot.

    Thin task wrapper around ``E.create_embedding_dataframe``. The task run
    is named ``embed-<time_idx>`` and its result is cached on its inputs for
    one day.

    Parameters:
    - snapshot: pd.DataFrame
        The snapshot to embed.
    - embed_columns: List[str]
        Columns handed to the helper as the embedding inputs.
    - embedder: str
        Passed to the helper as ``embedding_algorithm_str``.
    - embed_args: Dict[str, Any]
        Passed to the helper as ``embedding_kwargs``.
    - time_idx: str | int
        Label of this snapshot's timestep.
    - id_column: str
        Name of the identifier column.

    Returns:
    - pd.DataFrame
        Whatever ``E.create_embedding_dataframe`` returns.
    """
    # Translate this task's parameter names to the helper's keyword names.
    helper_kwargs = {
        "snapshot": snapshot,
        "embed_columns": embed_columns,
        "embedding_algorithm_str": embedder,
        "embedding_kwargs": embed_args,
        "id_column": id_column,
        "time_idx": time_idx,
    }
    return E.create_embedding_dataframe(**helper_kwargs)
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """
    Combine a list of per-timestep frames into one frame for plotting.

    Delegates to ``E.collect_and_prepare_for_plotly``, forwarding
    ``sort_time`` and ``id_column`` by keyword, and returns its result.
    """
    plot_ready = E.collect_and_prepare_for_plotly(
        embedded_dfs,
        sort_time=sort_time,
        id_column=id_column,
    )
    return plot_ready
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """
    Render the animated figure for ``combined_df`` and write it as HTML.

    The figure is built by ``E.plot_embedding_over_time``. Both durations
    are truncated to ``int`` before being passed on, so callers may supply
    floats. The parent directory of ``output_path`` is created if missing.
    The task is never cached and is retried up to three times on failure.

    Returns:
    - output_path, exactly as it was passed in.
    """
    frame_len = int(frame_duration)
    transition_len = int(transition_duration)
    figure = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=frame_len,
        transition_duration=transition_len,
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )

    # Make sure the destination directory exists before writing the file.
    destination_dir = Path(output_path).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
# Flow-level defaults. They are kept at module scope (instead of as mutable
# default arguments) and are copied inside embedding_flow before use.
# Generator kwargs act as a base that caller-supplied keys override.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = dict(random_state=0)
# Column order matters: the flow's reference plot uses the first two entries.
_DEFAULT_EMBED_COLUMNS: List[str] = [f"feature_{axis}" for axis in (0, 2, 1)]
# Used only when the caller passes no embed_args at all.
_DEFAULT_EMBED_ARGS: Dict[str, Any] = dict(n_components=2, random_state=30)
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
) -> tuple[str, str]:
    """
    Generate a data set, jitter it over time, embed every snapshot, and
    write two animated HTML figures: the reference data and its embedding.

    Parameters:
    - num_points: int
        Passed to the generator as ``n_samples`` (always overrides any
        ``n_samples`` in ``generator_kwargs``).
    - num_timesteps: int
        Number of snapshots requested and number of monthly time labels.
    - jitter_scale: float
        Forwarded to the snapshot task.
    - seed: int
        Forwarded to the snapshot task.
    - generator_path: str
        Dotted path of the data generator function.
    - generator_kwargs: Optional[Dict[str, Any]]
        Extra generator keyword arguments, merged over
        ``_DEFAULT_GENERATOR_KWARGS``.
    - embed_columns: Optional[List[str]]
        Columns to embed; defaults to ``_DEFAULT_EMBED_COLUMNS``. The first
        two entries are used as x and y of the reference plot.
    - embedder: str
        Dotted path of the embedding algorithm.
    - embed_args: Optional[Dict[str, Any]]
        Keyword arguments for the embedder; defaults to
        ``_DEFAULT_EMBED_ARGS`` when omitted.
    - output_dir: str
        Directory that receives the HTML files (created if missing).
    - id_column: str
        Name of the identifier column used throughout the pipeline.
    - frame_duration, transition_duration: int
        Animation timings for the embedding plot.
    - reference_speedup: float
        Divisor applied to both timings for the reference plot, with floors
        of 175 (frame) and 350 (transition).
    - samples: int
        Forwarded to the plotting task.

    Returns:
    - (reference_path, embedding_path): the two HTML paths that were written.
    """
    # Imported here so that this block brings its own dependency into scope.
    from prefect import unmapped

    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    # Copy so neither the module defaults nor the caller's objects are shared.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)

    # Build output paths with pathlib. The previous `output_dir.strip('/')`
    # also removed a leading slash, so an absolute output_dir was created in
    # one place while the files were written to a relative path elsewhere.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    generator_name = generator_path.split(".")[-1]
    embedder_name = embedder.split(".")[-1]
    run_tag = f"N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}"
    output_ref: str = (out_dir / f"{generator_name}_Reference_{run_tag}.html").as_posix()
    output_embed: str = (
        out_dir / f"{generator_name}_{embedder_name}_{run_tag}.html"
    ).as_posix()

    title_ref = f"Reference: {generator_name}, N={num_points} with {jitter_scale} noise"
    title_embed = (
        f"Embedding: {embedder_name} on {generator_name}, "
        f"N={num_points} with {jitter_scale} noise"
    )

    # Generate initial frame using the specified data generator
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )

    # Generate snapshots
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()

    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic reasons)
    dates = [
        f"{2000 + step // 12}-{step % 12 + 1:02d}-01"
        for step in range(num_timesteps)
    ]

    # Apply embeddings in parallel using Prefect's mapping. Only the snapshot
    # and its date vary per run; everything else is marked as a constant
    # instead of being replicated into num_timesteps-long lists.
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=unmapped(embed_columns),
        embedder=unmapped(embedder),
        embed_args=unmapped(embed_args),
        id_column=unmapped(id_column),
    )

    # Collect all embeddings. id_column is forwarded so a non-default
    # identifier column is honoured (it used to fall back to "id").
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False, id_column=id_column
    ).result()

    # make the original snapshots look like the embeddings
    reference_df = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False, id_column=id_column
    ).result()
    # Explicit copy: the frame is mutated below, and mutating a column subset
    # of another frame triggers pandas' SettingWithCopy behaviour.
    reference_df = reference_df[embed_columns[:2] + [id_column, "time"]].copy()
    reference_df.columns = ["x", "y", id_column, "time"]
    # Reuse the embedding's time labels. NOTE(review): this relies on both
    # collected frames having the same row order (sort_time=False for both).
    reference_df["time"] = combined_df["time"].to_numpy()

    # Plot reference animation
    ref_path = plot_and_save_task.submit(
        combined_df=reference_df,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    # Plot embedding animation
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    return (ref_path.result(), emb_path.result())
if __name__ == "__main__":
    # Serve the flow as a deployment when this file is executed as a script.
    # For a single direct run, call embedding_flow() here instead.
    embedding_flow.serve()