dr-sandbox/flows/embedding_utils.py
2026-04-21 18:02:39 -06:00

492 lines
17 KiB
Python

# embedding_utils.py
import importlib
from typing import List, Optional, Type, Union
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.graph_objects import Figure
def dynamic_import(class_path: str) -> Type:
    """
    Dynamically import a class from a given module path.

    Parameters:
    - class_path: str
        The full dotted path to the class (e.g., 'sklearn.decomposition.PCA').

    Returns:
    - cls: Type
        The attribute named by the last path component, looked up on the
        module named by everything before it.

    Raises:
    - ImportError: If the path is not of the form 'module.Name', or if the
      module or the attribute cannot be found. The underlying exception is
      attached as ``__cause__``.
    """
    try:
        # rpartition never fails to unpack, so a path without a dot is
        # reported through the same ImportError as any other bad path.
        module_path, _, class_name = class_path.rpartition(".")
        if not (module_path and class_name):
            raise ValueError(
                "expected a dotted path like 'package.module.ClassName'"
            )
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Cannot import '{class_path}'. Error: {e}") from e
def create_embedding_dataframe(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedding_algorithm_str: str = "sklearn.decomposition.PCA",
    embedding_kwargs: Optional[dict] = None,
    label_columns: Optional[List[str]] = None,
    id_column: Optional[str] = None,
    time_idx: Optional[Union[int, str]] = None,
) -> pd.DataFrame:
    """
    Apply an embedding algorithm to a single snapshot and prepare the DataFrame.

    Parameters:
    - snapshot: pd.DataFrame
        The input data snapshot to embed.
    - embed_columns: List[str]
        Columns of `snapshot` whose values are passed to the embedding model.
    - embedding_algorithm_str: str
        The full module path to the embedding class (e.g., 'sklearn.decomposition.PCA').
        The class is instantiated with `embedding_kwargs` and must provide
        `fit_transform`.
    - embedding_kwargs: Optional[dict]
        Additional keyword arguments for the embedding algorithm.
    - label_columns: Optional[List[str]]
        List of column names to include in the tooltip labels. If None or empty, labels are empty.
    - id_column: Optional[str]
        Column name to use as a unique identifier. If None (or not a column of
        `snapshot`), the snapshot's index values are used as IDs.
    - time_idx: Optional[Union[int, str]]
        The time identifier for the snapshot (e.g., integer index or 'YYYYMMDD' string).
        If None, the snapshot's own 'time' column is used.

    Returns:
    - embedded_df: pd.DataFrame
        DataFrame containing 'id', 'x', 'y', 'time', and 'label' columns, one row
        per snapshot row in the original row order, with a fresh 0..n-1 index.

    Raises:
    - ValueError: If a label column is missing from `snapshot`, or if the
      embedding output does not have exactly two columns.
    - KeyError: If `time_idx` is None and `snapshot` has no 'time' column.
    - ImportError: If the embedding class cannot be imported.
    """
    if embedding_kwargs is None:
        embedding_kwargs = {}
    if label_columns is None:
        label_columns = []

    # Validate labels before fitting so a bad column name fails fast.
    missing_cols = [col for col in label_columns if col not in snapshot.columns]
    if missing_cols:
        raise ValueError(f"Label columns not found in snapshot: {missing_cols}")

    # Dynamically import the embedding class, then initialize and fit the model
    embedding_class = dynamic_import(embedding_algorithm_str)
    model = embedding_class(**embedding_kwargs)
    embedded = np.asarray(model.fit_transform(snapshot[embed_columns].values))
    if embedded.ndim != 2 or embedded.shape[1] != 2:
        raise ValueError("Embedding must result in 2 dimensions.")

    # The embedding output is positional (row k belongs to snapshot row k), so
    # every column taken from the snapshot has its index dropped before being
    # combined. Aligning on the snapshot's index instead would mismatch or
    # NaN-fill rows whenever that index is not already 0..n-1.
    if id_column and id_column in snapshot.columns:
        ids = snapshot[id_column].reset_index(drop=True)
    else:
        ids = pd.Series(snapshot.index).reset_index(drop=True)
    embedded_df = pd.DataFrame(
        {"id": ids, "x": embedded[:, 0], "y": embedded[:, 1]}
    )

    if time_idx is not None:
        embedded_df["time"] = time_idx
    else:  # if not supplied, use "time" from snapshot
        embedded_df["time"] = snapshot["time"].reset_index(drop=True)

    # Create tooltip labels
    if label_columns:
        # Concatenate specified columns into a single string for the tooltip
        labels = snapshot[label_columns].astype(str).agg(" | ".join, axis=1)
        embedded_df["label"] = labels.reset_index(drop=True)
    else:
        embedded_df["label"] = ""
    return embedded_df
def collect_and_prepare_for_plotly(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """
    Stack per-snapshot embedding frames into one frame ready for Plotly.

    Parameters:
    - embedded_dfs: List[pd.DataFrame]
        Frames to stack row-wise; each is expected to carry 'id', 'x', 'y',
        'time', and 'label' columns.
    - sort_time: bool
        When True, rows are ordered by 'time' and then by 'id'.
    - id_column: str
        Fallback identifier column. If the stacked frame has no 'id' column
        but has this one, it is renamed to 'id'.

    Returns:
    - combined_df: pd.DataFrame
        The stacked (and optionally sorted) frame with a fresh 0..n-1 index.

    Raises:
    - ValueError: If `embedded_dfs` is empty, or if neither 'id' nor
      `id_column` is present after stacking.
    """
    if not embedded_dfs:
        raise ValueError("The list of embedded DataFrames is empty.")

    stacked = pd.concat(embedded_dfs, ignore_index=True)

    # Normalise the identifier column name to 'id'
    if "id" not in stacked.columns:
        if id_column not in stacked.columns:
            raise ValueError(
                "Each embedded DataFrame must contain an 'id' column for sorting."
            )
        stacked = stacked.rename(columns={id_column: "id"})

    # String dates such as 'YYYYMMDD' sort correctly lexicographically, so
    # numeric and string 'time' values share the same sort call.
    if sort_time:
        stacked = stacked.sort_values(by=["time", "id"])

    return stacked.reset_index(drop=True)
def plot_embedding_over_time(
    combined_df: pd.DataFrame,
    title: str = "Embedding Over Time",
    color_column: Optional[str] = None,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    frame_duration: int = 500,
    transition_duration: int = 500,
    samples: int = 0,
) -> Figure:
    """
    Create an interactive Plotly scatter plot with animation over time.

    The x and y axes are the first two numeric (float/int/bool) columns of
    `combined_df` other than 'id' and 'time'. Frames are keyed on 'time' and
    points are tracked across frames by 'id'. No axis-selection dropdowns are
    attached to the figure.

    Parameters:
    - combined_df: pd.DataFrame
        DataFrame containing at least 'id', 'time', and two numerical feature columns.
    - title: str
        Title of the plot.
    - color_column: Optional[str]
        Column name for color encoding. If None or not a column of
        `combined_df`, no color encoding is applied.
    - fixed_axes: bool
        If True, axis ranges are fixed to the min/max of the plotted data so
        they stay constant across all frames.
    - equal_aspect: bool
        If True, the y axis is anchored to the x axis with a 1:1 scale ratio.
    - frame_duration: int
        Duration of each animation frame in milliseconds.
    - transition_duration: int
        Duration of the transition between frames in milliseconds.
    - samples: int (optional)
        If > 0, at most this many distinct ids are drawn at random (without
        replacement, using NumPy's global RNG) and only their rows are plotted.

    Returns:
    - fig: plotly.graph_objs._figure.Figure
        The Plotly figure object.

    Raises:
    - ValueError: If fewer than two usable numeric columns are present.
    """
    # Step 1: Identify numerical columns excluding 'id' and 'time'
    numeric_columns = [
        col
        for col in combined_df.select_dtypes(include=["float", "int", "bool"]).columns
        if col not in ["id", "time"]
    ]
    if len(numeric_columns) < 2:
        raise ValueError(
            "DataFrame must have at least two numerical columns for x and y axes."
        )

    # Step 2: Use the first two numerical columns as x and y
    default_x, default_y = numeric_columns[:2]
    hover_data = ["id"]

    # Step 3: Sample the data if required
    if samples > 0:
        unique_ids = combined_df["id"].unique().tolist()
        samples = min(samples, len(unique_ids))
        sample_ids = np.random.choice(unique_ids, samples, replace=False)
        combined_df_sample = combined_df[combined_df["id"].isin(sample_ids)]
    else:
        combined_df_sample = combined_df

    # Step 4: Fade points as the number of ids grows (full opacity up to 5000
    # ids, never below 0.1). An empty frame has no ids; skip the division.
    id_count = combined_df_sample["id"].nunique()
    opacity = max(0.1, min(5000.0 / id_count, 1)) if id_count else 1

    # Step 5: Create the scatter plot; `color` is only passed when usable
    scatter_kwargs = dict(
        x=default_x,
        y=default_y,
        animation_frame="time",
        animation_group="id",
        hover_data=hover_data,
        title=title,
        labels={default_x: "x", default_y: "y", "time": "Time"},
        # Frame order comes from the full data, not just the sample
        category_orders={"time": sorted(combined_df["time"].unique())},
        opacity=opacity,
    )
    if color_column and color_column in combined_df.columns:
        scatter_kwargs["color"] = color_column
    fig = px.scatter(combined_df_sample, **scatter_kwargs)

    # Step 6: Fix axes ranges if required
    if fixed_axes:
        fig.update_layout(
            xaxis=dict(
                range=[
                    combined_df_sample[default_x].min(),
                    combined_df_sample[default_x].max(),
                ]
            ),
            yaxis=dict(
                range=[
                    combined_df_sample[default_y].min(),
                    combined_df_sample[default_y].max(),
                ]
            ),
        )

    # Step 7: Enforce equal aspect ratio if required
    if equal_aspect:
        fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Step 8: Axis titles and figure size. The larger margin is kept for
    # frames with extra numeric columns so the rendered layout is unchanged.
    layout_kwargs = dict(
        xaxis_title=default_x.replace("_", " ").title(),
        yaxis_title=default_y.replace("_", " ").title(),
        width=800,
        height=800,
    )
    if len(numeric_columns) > 2:
        layout_kwargs["margin"] = dict(t=100, b=150)
    fig.update_layout(**layout_kwargs)

    # Step 9: Apply the requested durations to every update-menu button whose
    # second argument is a dict.
    # NOTE(review): presumably these are the play/pause buttons Plotly Express
    # adds for animated figures - confirm the pause button should also receive
    # a non-zero duration.
    if fig.layout.updatemenus:
        for updatemenu in fig.layout.updatemenus:
            if "buttons" in updatemenu:
                for btn in updatemenu.buttons:
                    if (
                        "args" in btn
                        and len(btn.args) > 1
                        and isinstance(btn.args[1], dict)
                    ):
                        frame = btn.args[1].get("frame", {})
                        transition = btn.args[1].get("transition", {})
                        frame["duration"] = frame_duration
                        transition["duration"] = transition_duration
                        btn.args[1]["frame"] = frame
                        btn.args[1]["transition"] = transition
    return fig
def generate_initial_frame(
    num_points: int, num_features: int, seed: int = 42, id_prefix: str = "Point"
) -> pd.DataFrame:
    """
    Build the starting frame: standard-normal features plus integer IDs.

    Parameters:
    - num_points: int
        Number of rows (data points) to generate.
    - num_features: int
        Number of feature columns, named 'feature_0' ... 'feature_{n-1}'.
    - seed: int
        Seed applied to NumPy's global random state before sampling.
    - id_prefix: str
        Accepted but not used; IDs are the integers 0 .. num_points - 1.

    Returns:
    - df: pd.DataFrame
        Feature columns followed by an integer 'id' column and a 'time'
        column set to 0 for every row.
    """
    np.random.seed(seed)
    feature_names = [f"feature_{idx}" for idx in range(num_features)]
    frame = pd.DataFrame(
        np.random.randn(num_points, num_features), columns=feature_names
    )
    frame["id"] = np.arange(num_points).astype(int)
    frame["time"] = 0
    return frame
def generate_jittered_snapshots(
    initial_df: pd.DataFrame,
    num_snapshots: int,
    jitter_scale: float = 0.1,
    seed: int = 42,
) -> List[pd.DataFrame]:
    """
    Generate snapshots by applying random jitter to the initial frame and randomly adding/removing points.

    Each snapshot is derived from the previous one: Gaussian noise is added to
    every feature column, then one point is either added (with a fresh ID) or
    removed, each with probability 0.5. A removal is skipped when only one
    point is left.

    Parameters:
    - initial_df: pd.DataFrame
        The initial DataFrame to apply jitter. Must contain an 'id' column;
        every column other than 'id' and 'time' is treated as a numeric feature.
    - num_snapshots: int
        Number of snapshots to generate.
    - jitter_scale: float
        Standard deviation of the Gaussian noise added for jitter.
    - seed: int
        Seed applied to NumPy's global random state for reproducibility.

    Returns:
    - snapshots: List[pd.DataFrame]
        List of jittered DataFrames (feature columns, then 'id', then 'time'),
        with 'time' running from 1 to `num_snapshots`.
    """
    np.random.seed(seed)
    snapshots = []
    current_df = initial_df.copy()

    # Select features by name rather than by position so the result does not
    # depend on 'id'/'time' being the last two columns, and so added points
    # reuse the real feature names instead of introducing new NaN-filled ones.
    feature_cols = [col for col in initial_df.columns if col not in ("id", "time")]

    # IDs are issued from a counter that only ever increases. Deriving the new
    # ID from the current maximum would hand a removed point's ID to an
    # unrelated new point whenever the highest-ID point had just been removed.
    next_id = initial_df["id"].max() + 1 if len(initial_df) else 0

    for i in range(num_snapshots):
        # Apply jitter to every feature of every current point
        jitter = np.random.normal(
            loc=0.0,
            scale=jitter_scale,
            size=(len(current_df), len(feature_cols)),
        )
        jittered_df = current_df[feature_cols] + jitter
        jittered_df["id"] = current_df["id"]

        # Randomly decide to add or remove points ("none" has probability 0)
        action = np.random.choice(["add", "remove", "none"], p=[0.5, 0.5, 0])
        if action == "add":
            # Add a new point with a never-before-used integer ID
            new_point = np.random.randn(1, len(feature_cols))
            new_df = pd.DataFrame(new_point, columns=feature_cols)
            new_df["id"] = next_id
            next_id += 1
            jittered_df = pd.concat([jittered_df, new_df], ignore_index=True)
        elif action == "remove" and len(jittered_df) > 1:
            # Remove a random point
            remove_idx = np.random.choice(jittered_df.index)
            jittered_df = jittered_df.drop(index=remove_idx).reset_index(drop=True)

        # Assign time index
        jittered_df["time"] = i + 1  # Start from 1
        snapshots.append(jittered_df)

        # Update current_df for next iteration
        current_df = jittered_df.copy()
    return snapshots