492 lines
17 KiB
Python
492 lines
17 KiB
Python
# embedding_utils.py
|
|
|
|
import importlib
|
|
from typing import List, Optional, Type, Union
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import plotly.express as px
|
|
from plotly.graph_objects import Figure
|
|
|
|
|
|
def dynamic_import(class_path: str) -> Type:
    """
    Dynamically import a class from a given module path.

    Parameters:
    - class_path: str
        The full dotted path to the class (e.g., 'sklearn.decomposition.PCA').
        Everything before the last '.' is the module; the last component is
        the attribute looked up on that module.

    Returns:
    - cls: Type
        The imported class.

    Raises:
    - ImportError: If the path is malformed (e.g. contains no '.'), or the
      module or class cannot be found. The underlying exception is chained
      as ``__cause__``.
    """
    try:
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError, TypeError) as e:
        # ValueError: nothing to unpack (no '.') or an empty module name
        # (".Cls"); TypeError: relative paths such as "..Cls" need a package
        # anchor. All of them mean "cannot import", so report them uniformly.
        raise ImportError(f"Cannot import '{class_path}'. Error: {e}") from e
|
|
|
|
|
|
def create_embedding_dataframe(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedding_algorithm_str: str = "sklearn.decomposition.PCA",
    embedding_kwargs: Optional[dict] = None,
    label_columns: Optional[List[str]] = None,
    id_column: Optional[str] = None,
    time_idx: Optional[Union[int, str]] = None,
) -> pd.DataFrame:
    """
    Apply an embedding algorithm to a single snapshot and prepare the DataFrame.

    Row i of the result always corresponds to row i of ``snapshot``: every
    output column is built positionally, so snapshots with a non-default index
    (e.g. after filtering) do not produce misaligned rows or NaNs.

    Parameters:
    - snapshot: pd.DataFrame
        The input data snapshot to embed.
    - embed_columns: List[str]
        Columns of ``snapshot`` whose values are passed to the model's
        ``fit_transform``.
    - embedding_algorithm_str: str
        The full module path to the embedding class (e.g., 'sklearn.decomposition.PCA').
    - embedding_kwargs: Optional[dict]
        Additional keyword arguments for the embedding algorithm's constructor.
    - label_columns: Optional[List[str]]
        List of column names joined with " | " into the tooltip label.
        If None or empty, labels are empty strings.
    - id_column: Optional[str]
        Column name to use as a unique identifier. If None or not a column of
        ``snapshot``, the snapshot's index values are used as ids.
    - time_idx: Optional[Union[int, str]]
        The time identifier for the snapshot (e.g., integer index or 'YYYYMMDD'
        string). If None, the snapshot's own 'time' column is used.

    Returns:
    - embedded_df: pd.DataFrame
        DataFrame with columns 'id', 'x', 'y', 'time', and 'label' and a
        fresh RangeIndex.

    Raises:
    - ValueError: If a label column is missing, or the embedding output is
      not two-dimensional with exactly 2 columns.
    - KeyError: If ``time_idx`` is None and ``snapshot`` has no 'time' column.
    - ImportError: If the embedding class cannot be imported.
    """
    if embedding_kwargs is None:
        embedding_kwargs = {}
    if label_columns is None:
        label_columns = []

    # Validate cheap preconditions before fitting a potentially expensive model.
    missing_cols = [col for col in label_columns if col not in snapshot.columns]
    if missing_cols:
        raise ValueError(f"Label columns not found in snapshot: {missing_cols}")
    if time_idx is None and "time" not in snapshot.columns:
        raise KeyError("snapshot has no 'time' column; pass time_idx explicitly")

    # Dynamically import, initialize and fit the embedding model.
    embedding_class = dynamic_import(embedding_algorithm_str)
    model = embedding_class(**embedding_kwargs)
    # np.asarray also copes with estimators configured to return DataFrames,
    # whose column names would otherwise be reindexed away by pandas.
    embedded = np.asarray(model.fit_transform(snapshot[embed_columns].values))
    if embedded.ndim != 2 or embedded.shape[1] != 2:
        raise ValueError("Embedding must result in 2 dimensions.")

    # A positionally indexed copy lines up with the rows of `embedded`.
    rows = snapshot.reset_index(drop=True)

    if id_column and id_column in rows.columns:
        ids = rows[id_column]
    else:
        # An Index (unlike a Series) is taken positionally by the constructor.
        ids = snapshot.index

    times = time_idx if time_idx is not None else rows["time"]

    if label_columns:
        # Concatenate specified columns into a single string for the tooltip.
        labels = rows[label_columns].astype(str).agg(" | ".join, axis=1)
    else:
        labels = ""

    return pd.DataFrame(
        {
            "id": ids,
            "x": embedded[:, 0],
            "y": embedded[:, 1],
            "time": times,
            "label": labels,
        },
        index=rows.index,
    )
|
|
|
|
|
|
def collect_and_prepare_for_plotly(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """
    Stack per-snapshot embedding frames into a single Plotly-ready DataFrame.

    Parameters:
    - embedded_dfs: List[pd.DataFrame]
        DataFrames that each contain 'id', 'x', 'y', 'time', and 'label' columns.
    - sort_time: bool
        If True, order the stacked rows by 'time' and then by 'id'.
    - id_column: str
        Fallback identifier column: when the stacked frame has no 'id' column
        but has this one, it is renamed to 'id'.

    Returns:
    - combined_df: pd.DataFrame
        All input rows concatenated, with a fresh RangeIndex, sorted by
        ('time', 'id') when ``sort_time`` is True.

    Raises:
    - ValueError: If ``embedded_dfs`` is empty, or neither an 'id' column nor
      ``id_column`` is present.
    """
    if not embedded_dfs:
        raise ValueError("The list of embedded DataFrames is empty.")

    stacked = pd.concat(embedded_dfs, ignore_index=True)

    already_has_id = "id" in stacked.columns
    if not already_has_id and id_column not in stacked.columns:
        raise ValueError(
            "Each embedded DataFrame must contain an 'id' column for sorting."
        )
    if not already_has_id:
        stacked = stacked.rename(columns={id_column: "id"})

    if not sort_time:
        # ignore_index=True above already gave the rows a clean RangeIndex.
        return stacked

    # Numeric times sort numerically; string times such as 'YYYYMMDD' sort
    # lexicographically, which is also chronological for that format.
    ordered = stacked.sort_values(by=["time", "id"])
    return ordered.reset_index(drop=True)
|
|
|
|
|
|
def plot_embedding_over_time(
    combined_df: pd.DataFrame,
    title: str = "Embedding Over Time",
    color_column: Optional[str] = None,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    frame_duration: int = 500,
    transition_duration: int = 500,
    samples: int = 0,
) -> Figure:
    """
    Create an interactive Plotly scatter plot with animation over time.

    The first two numerical columns other than 'id' and 'time' (in column
    order) are used as the x and y axes. One animation frame is drawn per
    'time' value, and points are matched across frames by 'id'.

    Parameters:
    - combined_df: pd.DataFrame
        DataFrame containing at least 'id', 'time', and two numerical feature
        columns. If a 'label' column is present it is shown in the tooltip.
    - title: str
        Title of the plot.
    - color_column: Optional[str]
        Column name for color encoding. Ignored if None or not a column of
        ``combined_df``.
    - fixed_axes: bool
        If True, axes ranges are fixed to the min/max of the plotted data so
        they stay constant across all frames.
    - equal_aspect: bool
        If True, the plot will have an equal aspect ratio.
    - frame_duration: int
        Duration of each animation frame in milliseconds (applied to the
        animation buttons).
    - transition_duration: int
        Duration of the transition between frames in milliseconds.
    - samples: int (optional)
        If > 0, plot only this many randomly chosen ids (all of their rows)
        for faster rendering. Sampling uses numpy's global random state.

    Returns:
    - fig: plotly.graph_objs._figure.Figure
        The Plotly figure object.

    Raises:
    - ValueError: If fewer than two numerical columns besides 'id' and 'time'
      are available.
    """
    # Step 1: numerical columns, excluding the identifier and time axis.
    numeric_columns = [
        col
        for col in combined_df.select_dtypes(include=["float", "int", "bool"]).columns
        if col not in ("id", "time")
    ]
    if len(numeric_columns) < 2:
        raise ValueError(
            "DataFrame must have at least two numerical columns for x and y axes."
        )

    # Step 2: the first two numerical columns become x and y.
    default_x, default_y = numeric_columns[0], numeric_columns[1]

    # Tooltip: always the id, plus the prepared tooltip text when present.
    hover_data = ["id"]
    if "label" in combined_df.columns:
        hover_data.append("label")

    # Step 3: optionally restrict the plot to a random subset of ids.
    combined_df_sample = _sample_ids(combined_df, samples)

    # Step 4: fade points once there are more than 5000 ids (floor 0.1);
    # max(..., 1) avoids dividing by zero on an empty frame.
    n_ids = combined_df_sample["id"].nunique()
    opacity = max(0.1, min(5000.0 / max(n_ids, 1), 1))

    # Step 5: animated scatter; color is only applied for an existing column.
    use_color = bool(color_column) and color_column in combined_df.columns
    fig = px.scatter(
        combined_df_sample,
        x=default_x,
        y=default_y,
        animation_frame="time",
        animation_group="id",
        color=color_column if use_color else None,
        hover_data=hover_data,
        title=title,
        labels={default_x: "x", default_y: "y", "time": "Time"},
        category_orders={"time": sorted(combined_df["time"].unique())},
        opacity=opacity,
    )

    # Step 6: keep the axes constant across frames.
    if fixed_axes:
        fig.update_layout(
            xaxis=dict(
                range=[
                    combined_df_sample[default_x].min(),
                    combined_df_sample[default_x].max(),
                ]
            ),
            yaxis=dict(
                range=[
                    combined_df_sample[default_y].min(),
                    combined_df_sample[default_y].max(),
                ]
            ),
        )

    # Step 7: equal aspect ratio.
    if equal_aspect:
        fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Step 8: titles and size. Wider feature sets keep the extra top/bottom
    # margin that was originally reserved for axis-selector menus.
    layout = dict(
        xaxis_title=default_x.replace("_", " ").title(),
        yaxis_title=default_y.replace("_", " ").title(),
        width=800,
        height=800,
    )
    if len(numeric_columns) > 2:
        layout["margin"] = dict(t=100, b=150)
    fig.update_layout(**layout)

    # Step 9: animation speed.
    _set_animation_durations(fig, frame_duration, transition_duration)
    return fig


def _sample_ids(df: pd.DataFrame, samples: int) -> pd.DataFrame:
    """Return all rows of up to ``samples`` randomly chosen ids (``df`` itself if ``samples <= 0``)."""
    if samples <= 0:
        return df
    unique_ids = df["id"].unique().tolist()
    chosen = np.random.choice(unique_ids, min(samples, len(unique_ids)), replace=False)
    return df[df["id"].isin(chosen)]


def _set_animation_durations(
    fig: "Figure", frame_duration: int, transition_duration: int
) -> None:
    """Set frame/transition durations in the animation settings of every update-menu button."""
    for updatemenu in fig.layout.updatemenus or ():
        for btn in updatemenu.buttons or ():
            args = btn.args
            if args is None or len(args) < 2 or not isinstance(args[1], dict):
                continue
            # args[1] is the animation-options dict; mutating it in place
            # updates the figure's layout.
            settings = args[1]
            settings.setdefault("frame", {})["duration"] = frame_duration
            settings.setdefault("transition", {})["duration"] = transition_duration
|
|
|
|
|
|
def generate_initial_frame(
    num_points: int, num_features: int, seed: int = 42, id_prefix: str = "Point"
) -> pd.DataFrame:
    """
    Build the time-0 frame of standard-normal random points.

    Note: this seeds numpy's global random state with ``seed``.

    Parameters:
    - num_points: int
        Number of data points (rows).
    - num_features: int
        Number of feature columns per data point.
    - seed: int
        Random seed for reproducibility.
    - id_prefix: str
        Currently unused: ids are the integers 0..num_points-1. Kept for
        backward compatibility of the signature.

    Returns:
    - df: pd.DataFrame
        Columns 'feature_0' .. 'feature_{num_features-1}', then 'id'
        (consecutive integers from 0) and 'time' (always 0).
    """
    np.random.seed(seed)
    feature_names = [f"feature_{col}" for col in range(num_features)]
    points = pd.DataFrame(
        np.random.randn(num_points, num_features), columns=feature_names
    )
    # 'id' and 'time' stay the last two columns: generate_jittered_snapshots
    # treats every column before them as a feature.
    return points.assign(id=np.arange(num_points).astype(int), time=0)
|
|
|
|
|
|
def generate_jittered_snapshots(
    initial_df: pd.DataFrame,
    num_timesteps: int,
    jitter_scale: float = 0.1,
    seed: int = 42,
) -> List[pd.DataFrame]:
    """
    Generate one jittered snapshot per timestep, with random point add/remove.

    Each step adds Gaussian noise to every feature and then, with equal
    probability, either appends one new random point or drops one random
    point (a removal is skipped when only one point is left). New points get
    ids that were never used before. An id removed at one step is therefore
    never recycled, which would otherwise make animations grouped by 'id'
    morph an unrelated point out of the removed one.

    Note: this seeds numpy's global random state with ``seed``.

    Parameters:
    - initial_df: pd.DataFrame
        The initial DataFrame to apply jitter to. Its last two columns must be
        'id' and 'time' (as produced by ``generate_initial_frame``); every
        column before them is treated as a numeric feature.
    - num_timesteps: int
        Number of timesteps (one snapshot produced per timestep).
    - jitter_scale: float
        Standard deviation of the Gaussian noise added for jitter.
    - seed: int
        Random seed for reproducibility.

    Returns:
    - snapshots: List[pd.DataFrame]
        ``num_timesteps`` DataFrames, each with the feature columns of
        ``initial_df`` followed by 'id' and 'time'. 'time' is the 1-based
        step number.
    """
    np.random.seed(seed)
    feature_cols = list(initial_df.columns[:-2])  # exclude 'id' and 'time'
    # Track the next unused id rather than recomputing max()+1 per step:
    # max()+1 would reuse the id of a previously removed max-id point.
    next_id = initial_df["id"].max() + 1 if len(initial_df) else 0

    snapshots = []
    current_df = initial_df.copy()

    for i in range(num_timesteps):
        jitter = np.random.normal(
            loc=0.0,
            scale=jitter_scale,
            size=(current_df.shape[0], len(feature_cols)),
        )
        jittered_df = current_df.iloc[:, :-2] + jitter
        jittered_df["id"] = current_df["id"]

        # Randomly add or remove a point ("none" has probability 0).
        action = np.random.choice(["add", "remove", "none"], p=[0.5, 0.5, 0])

        if action == "add":
            # Reuse the existing feature names so the new row aligns on concat.
            new_point = np.random.randn(1, len(feature_cols))
            new_df = pd.DataFrame(new_point, columns=feature_cols)
            new_df["id"] = next_id
            next_id += 1
            jittered_df = pd.concat([jittered_df, new_df], ignore_index=True)
        elif action == "remove" and len(jittered_df) > 1:
            remove_idx = np.random.choice(jittered_df.index)
            jittered_df = jittered_df.drop(index=remove_idx).reset_index(drop=True)

        jittered_df["time"] = i + 1  # time steps start from 1
        snapshots.append(jittered_df)

        # The next step jitters from this snapshot.
        current_df = jittered_df.copy()

    return snapshots
|