# embedding_flow.py
"""Prefect flow that animates dimensionality-reduction embeddings over time."""

import os
import sys

# Default to the local Docker Prefect server. An explicit PREFECT_API_URL
# in the environment still wins (setdefault is a no-op if the key exists).
# NOTE(review): these are set before the prefect imports below — presumably
# prefect reads the environment at import time; confirm before reordering.
os.environ.setdefault("PREFECT_API_URL", "http://localhost:4200/api")
os.environ.setdefault("DO_NOT_TRACK", "1")  # opt out of telemetry

from datetime import timedelta
import math
from pathlib import Path
from typing import Any, Dict, List, Optional

from prefect import flow, task
from prefect.cache_policies import INPUTS, NO_CACHE
from prefect_ray import RayTaskRunner

import pandas as pd
from sklearn.preprocessing import StandardScaler

import embedding_utils as E  # project-local helpers that do the real work
from joblib import cpu_count
|
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=1))
def generate_initial_frame_task(
    generator_path: str, generator_kwargs: Dict[str, Any], id_column: str = "id"
) -> pd.DataFrame:
    """
    Generate the initial data frame using a specified data generator.

    Parameters:
    - generator_path: str
        The full module path to the data generator function (e.g., 'sklearn.datasets.make_s_curve').
    - generator_kwargs: Dict[str, Any]
        Keyword arguments to pass to the data generator function.
    - id_column: str
        Column name to use as a unique identifier.

    Returns:
    - df: pd.DataFrame
        DataFrame with one ``feature_<i>`` column per generated dimension,
        a unique integer ID column, and a ``time`` column fixed at 0.
    """
    generator_func = E.dynamic_import(generator_path)
    data, labels = generator_func(**generator_kwargs)

    # Per-feature z-score so jitter_scale has consistent meaning across
    # generators and reducers see comparably-scaled inputs.
    data = StandardScaler().fit_transform(data)

    # Build one feature_<i> column per generated dimension instead of
    # hard-coding exactly three, so 2-D/4-D generators also work.
    # (Identical output to the previous version for 3-D generators.)
    columns: Dict[str, Any] = {
        f"feature_{i}": data[:, i] for i in range(data.shape[1])
    }
    columns[id_column] = range(data.shape[0])
    columns["time"] = 0  # all points start at timestep 0

    df = pd.DataFrame(columns)
    df[id_column] = df[id_column].astype(int)
    return df
|
|
|
|
|
|
@task(cache_policy=INPUTS, cache_expiration=timedelta(hours=12))
def generate_snapshots_task(
    initial_df: pd.DataFrame, num_timesteps: int, jitter_scale: float, seed: int = 42
) -> List[pd.DataFrame]:
    """Produce one jittered copy of ``initial_df`` per timestep.

    Thin task wrapper around ``embedding_utils.generate_jittered_snapshots``
    so the (potentially expensive) generation is cached by Prefect.
    """
    snapshots = E.generate_jittered_snapshots(
        initial_df, num_timesteps, jitter_scale, seed
    )
    return snapshots
|
|
|
|
|
|
@task(
    cache_policy=INPUTS,
    cache_expiration=timedelta(days=1),
    task_run_name="embed-{time_idx}",
)
def create_embedding(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedder: str,
    embed_args: Dict[str, Any],
    time_idx: str | int,
    id_column: str = "id",
) -> pd.DataFrame:
    """Reduce one snapshot with the configured embedding algorithm.

    Delegates to ``embedding_utils.create_embedding_dataframe``; wrapping it
    in a task lets per-timestep embeddings run and cache independently.
    """
    embedded = E.create_embedding_dataframe(
        snapshot=snapshot,
        embed_columns=embed_columns,
        embedding_algorithm_str=embedder,
        embedding_kwargs=embed_args,
        id_column=id_column,
        time_idx=time_idx,
    )
    return embedded
|
|
|
|
|
|
@task
def collect_data_task(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """Concatenate per-timestep frames into one Plotly-ready DataFrame."""
    combined = E.collect_and_prepare_for_plotly(
        embedded_dfs, sort_time=sort_time, id_column=id_column
    )
    return combined
|
|
|
|
|
|
@task(
    task_run_name="plot-{output_path}",
    retries=3,
    cache_policy=NO_CACHE,
)
def plot_and_save_task(
    combined_df: pd.DataFrame,
    title: str,
    output_path: str,
    frame_duration: int = 500,
    transition_duration: int = 500,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    samples: int = 25_000,
):
    """Render the animated embedding figure and write it as HTML.

    Returns ``output_path`` so downstream callers can depend on completion.
    """
    # Durations may arrive as floats (the flow divides them by
    # reference_speedup), so coerce to int before handing off.
    figure = E.plot_embedding_over_time(
        combined_df,
        title=title,
        frame_duration=int(frame_duration),
        transition_duration=int(transition_duration),
        fixed_axes=fixed_axes,
        equal_aspect=equal_aspect,
        samples=samples,
    )
    # Ensure the destination directory exists on whichever worker runs this.
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)
    return output_path
|
|
|
|
|
|
# Fallback configuration applied when the matching flow argument is None.
_DEFAULT_GENERATOR_KWARGS: Dict[str, Any] = {"random_state": 0}
# NOTE(review): column order is 0, 2, 1 — looks like a deliberate axis swap
# for the default s-curve generator; confirm before "fixing".
_DEFAULT_EMBED_COLUMNS: List[str] = ["feature_0", "feature_2", "feature_1"]
_DEFAULT_EMBED_ARGS: Dict[str, Any] = {"n_components": 2, "random_state": 30}
|
|
|
|
|
|
@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def embedding_flow(
    num_points: int = 5000,
    num_timesteps: int = 48,
    jitter_scale: float = 0.01,
    seed: int = 42,
    generator_path: str = "sklearn.datasets.make_s_curve",
    generator_kwargs: Optional[Dict[str, Any]] = None,
    embed_columns: Optional[List[str]] = None,
    embedder: str = "sklearn.decomposition.FactorAnalysis",
    embed_args: Optional[Dict[str, Any]] = None,
    output_dir: str = "figs",
    id_column: str = "id",
    frame_duration: int = 1200,
    transition_duration: int = 2400,
    reference_speedup: float = 10.0,
    samples: int = 10_000,
):
    """Animate raw jittered data next to its low-dimensional embedding.

    Pipeline: generate an initial frame -> jitter it into ``num_timesteps``
    snapshots -> embed every snapshot in parallel on Ray -> collect and
    write two animated Plotly HTML files (a sped-up "reference" of the raw
    data and the embedding itself).

    Parameters:
    - num_points: samples the generator produces (forced into
      ``generator_kwargs["n_samples"]``).
    - num_timesteps: number of jittered snapshots / animation frames.
    - jitter_scale, seed: noise scale and RNG seed for the jittering.
    - generator_path, generator_kwargs: dotted path and kwargs of the
      data-generating function.
    - embed_columns: columns fed to the embedder (defaults to
      ``_DEFAULT_EMBED_COLUMNS``); the first two also become the
      reference plot's x/y axes.
    - embedder, embed_args: dotted path and kwargs of the embedding
      algorithm.
    - output_dir: directory the HTML files are written into.
    - id_column: name of the per-point identifier column.
    - frame_duration, transition_duration: Plotly animation timings (ms)
      for the embedding figure; the reference runs ``reference_speedup``
      times faster (floored at 175/350 ms).
    - samples: max points plotted per frame.

    Returns:
    - (reference_html_path, embedding_html_path)
    """
    # Caller overrides win over the defaults; n_samples always tracks the
    # num_points flow parameter.
    generator_kwargs = {
        **_DEFAULT_GENERATOR_KWARGS,
        **(generator_kwargs or {}),
        "n_samples": num_points,
    }
    # Copy so the module-level default containers are never mutated.
    embed_columns = (
        list(embed_columns) if embed_columns is not None else list(_DEFAULT_EMBED_COLUMNS)
    )
    embed_args = dict(embed_args) if embed_args is not None else dict(_DEFAULT_EMBED_ARGS)

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    _generator = generator_path.split(".")[-1]
    _embedder = embedder.split(".")[-1]
    # Shared filename suffix encoding the run parameters.
    _stem = f"N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}"
    output_ref: str = f"{output_dir.strip('/')}/{_generator}_Reference_{_stem}.html"
    output_embed: str = f"{output_dir.strip('/')}/{_generator}_{_embedder}_{_stem}.html"
    title_ref = f"Reference: {_generator}, N={num_points} with {jitter_scale} noise"
    title_embed = f"Embedding: {_embedder} on {_generator}, N={num_points} with {jitter_scale} noise"

    # Generate initial frame using the specified data generator.
    initial_frame = generate_initial_frame_task.submit(
        generator_path=generator_path,
        generator_kwargs=generator_kwargs,
        id_column=id_column,
    )

    # Jitter the initial frame into one snapshot per timestep.
    snapshots = generate_snapshots_task.submit(
        initial_df=initial_frame.result(),
        num_timesteps=num_timesteps,
        jitter_scale=jitter_scale,
        seed=seed,
    )
    snapshot_list = snapshots.result()

    # One date per timestep (monthly, starting at 2000-01-01 for cosmetic
    # reasons — it only labels the animation slider).
    dates = [
        f"{year}-{month:02d}-01"
        for year in range(2000, 2001 + math.floor(num_timesteps / 12))
        for month in range(1, 13)
    ][:num_timesteps]

    # Apply embeddings in parallel using Prefect's task mapping; scalar
    # arguments are broadcast as equal-length lists.
    embeddings = create_embedding.map(
        snapshot=snapshot_list,
        time_idx=dates,
        embed_columns=[embed_columns] * num_timesteps,
        embedder=[embedder] * num_timesteps,
        embed_args=[embed_args] * num_timesteps,
        id_column=[id_column] * num_timesteps,
    )

    # Collect all embeddings into a single animation-ready frame.
    combined_df = collect_data_task.submit(
        embedded_dfs=embeddings.result(), sort_time=False
    ).result()

    # Make the original snapshots look like the embeddings: the first two
    # embed columns become x/y and the time labels are reused.
    dfr = collect_data_task.submit(
        embedded_dfs=snapshot_list, sort_time=False
    ).result()
    # .copy() avoids pandas' SettingWithCopy ambiguity on the column
    # assignment below (the original wrote into a slice view).
    dfr = dfr[embed_columns[:2] + [id_column, "time"]].copy()
    dfr.columns = ["x", "y", id_column, "time"]
    # NOTE(review): assumes combined_df and dfr have identical row order
    # and length — holds because both come from the same snapshot_list.
    dfr["time"] = combined_df["time"].to_numpy()

    # Plot reference animation (sped up relative to the embedding).
    ref_path = plot_and_save_task.submit(
        combined_df=dfr,
        title=title_ref,
        output_path=output_ref,
        frame_duration=max(frame_duration / reference_speedup, 175),
        transition_duration=max(transition_duration / reference_speedup, 350),
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    # Plot embedding animation at the requested timings.
    emb_path = plot_and_save_task.submit(
        combined_df=combined_df,
        title=title_embed,
        output_path=output_embed,
        frame_duration=frame_duration,
        transition_duration=transition_duration,
        fixed_axes=True,
        equal_aspect=False,
        samples=samples,
    )

    return (ref_path.result(), emb_path.result())
|
|
|
|
|
|
if __name__ == "__main__":
    # Long-running serve mode: exposes this flow to the Prefect server so
    # runs can be triggered from the UI/API. Uncomment the direct call
    # below instead to execute a single local run.
    embedding_flow.serve()
    # embedding_flow()
|