# embedding_utils.py
import importlib
from typing import List, Optional, Type, Union

import numpy as np
import pandas as pd
import plotly.express as px
from plotly.graph_objects import Figure


def dynamic_import(class_path: str) -> Type:
    """
    Dynamically import a class from a given module path.

    Parameters:
    - class_path: str
        The full path to the class (e.g., 'sklearn.decomposition.PCA').

    Returns:
    - cls: Type
        The imported class.

    Raises:
    - ImportError: If the path is malformed (no '.' separator or empty module
      part), or if the module or class cannot be found.
    """
    try:
        # ValueError covers "PCA" (nothing to unpack) and ".PCA" (empty module
        # name); both are reported as ImportError like any other bad path.
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Cannot import '{class_path}'. Error: {e}") from e


def create_embedding_dataframe(
    snapshot: pd.DataFrame,
    embed_columns: List[str],
    embedding_algorithm_str: str = "sklearn.decomposition.PCA",
    embedding_kwargs: Optional[dict] = None,
    label_columns: Optional[List[str]] = None,
    id_column: Optional[str] = None,
    time_idx: Optional[Union[int, str]] = None,
) -> pd.DataFrame:
    """
    Apply an embedding algorithm to a single snapshot and prepare the DataFrame.

    Parameters:
    - snapshot: pd.DataFrame
        The input data snapshot to embed. Its index may be arbitrary; rows are
        matched by position, not by index label.
    - embed_columns: List[str]
        Columns of `snapshot` passed to the embedding model's `fit_transform`.
    - embedding_algorithm_str: str
        The full module path to the embedding class (e.g., 'sklearn.decomposition.PCA').
    - embedding_kwargs: Optional[dict]
        Additional keyword arguments for the embedding algorithm.
    - label_columns: Optional[List[str]]
        List of column names to include in the tooltip labels, joined with
        " | ". If None or empty, labels are empty strings.
    - id_column: Optional[str]
        Column name to use as a unique identifier. If None (or not present in
        `snapshot`), the snapshot's index values are used as IDs.
    - time_idx: Optional[Union[int, str]]
        The time identifier for the snapshot (e.g., integer index or 'YYYYMMDD'
        string). If None, the snapshot's own 'time' column is used.

    Returns:
    - embedded_df: pd.DataFrame
        DataFrame with a fresh RangeIndex and columns 'id', 'x', 'y', 'time'
        and 'label', one row per snapshot row in the same order.

    Raises:
    - ImportError: If the embedding class cannot be imported.
    - ValueError: If a label column is missing, or if the embedding does not
      produce exactly 2 dimensions.
    - KeyError: If `time_idx` is None and `snapshot` has no 'time' column.
    """
    if embedding_kwargs is None:
        embedding_kwargs = {}
    if label_columns is None:
        label_columns = []

    # Validate label columns up front so a bad call fails before the fit.
    missing_cols = [col for col in label_columns if col not in snapshot.columns]
    if missing_cols:
        raise ValueError(f"Label columns not found in snapshot: {missing_cols}")

    # Every column below is built on a 0..n-1 RangeIndex. Mixing the snapshot's
    # own index with the RangeIndex of the coordinates would make pandas align
    # on labels and produce NaN/misplaced rows for non-default indexes.
    if id_column and id_column in snapshot.columns:
        ids = snapshot[id_column].reset_index(drop=True)
    else:
        ids = pd.Series(snapshot.index)
    embedded_df = pd.DataFrame({"id": ids})

    # Dynamically import, initialize and fit the embedding model.
    embedding_class = dynamic_import(embedding_algorithm_str)
    model = embedding_class(**embedding_kwargs)
    # np.asarray: some transformers can return a DataFrame with their own
    # column names, which would not map onto ["x", "y"].
    embedded = np.asarray(model.fit_transform(snapshot[embed_columns].values))
    if embedded.ndim != 2 or embedded.shape[1] != 2:
        raise ValueError("Embedding must result in 2 dimensions.")

    embedded_coords = pd.DataFrame(embedded, columns=["x", "y"])
    embedded_df = pd.concat([embedded_df, embedded_coords], axis=1)

    if time_idx is not None:
        embedded_df["time"] = time_idx
    else:
        # Not supplied: take the snapshot's own "time" column, row by row.
        embedded_df["time"] = snapshot["time"].reset_index(drop=True)

    # Tooltip labels: the requested columns joined into one string per row.
    if label_columns:
        labels = snapshot[label_columns].astype(str).agg(" | ".join, axis=1)
        embedded_df["label"] = labels.reset_index(drop=True)
    else:
        embedded_df["label"] = ""

    return embedded_df


def collect_and_prepare_for_plotly(
    embedded_dfs: List[pd.DataFrame], sort_time: bool = True, id_column: str = "id"
) -> pd.DataFrame:
    """
    Combine multiple embedded DataFrames and prepare them for Plotly visualization.

    Parameters:
    - embedded_dfs: List[pd.DataFrame]
        A list of DataFrames, each containing 'id', 'x', 'y', 'time', and 'label' columns.
    - sort_time: bool
        Whether to sort the combined DataFrame by the 'time' column and then by 'id'.
    - id_column: str
        Identifier column to rename to 'id' when the combined frame has no
        'id' column.

    Returns:
    - combined_df: pd.DataFrame
        A single DataFrame concatenating all embedded snapshots (with a fresh
        RangeIndex), sorted by time and id if specified.

    Raises:
    - ValueError: If `embedded_dfs` is empty, or if neither 'id' nor
      `id_column` is present.
    """
    if not embedded_dfs:
        raise ValueError("The list of embedded DataFrames is empty.")

    combined_df = pd.concat(embedded_dfs, ignore_index=True)

    if "id" not in combined_df.columns:
        if id_column not in combined_df.columns:
            raise ValueError(
                "Each embedded DataFrame must contain an 'id' column for sorting."
            )
        combined_df = combined_df.rename(columns={id_column: "id"})

    if sort_time:
        # 'time' may be numeric or a lexicographically sortable string such
        # as 'YYYYMMDD'; sort_values handles both.
        combined_df = combined_df.sort_values(by=["time", "id"]).reset_index(
            drop=True
        )

    return combined_df


def plot_embedding_over_time(
    combined_df: pd.DataFrame,
    title: str = "Embedding Over Time",
    color_column: Optional[str] = None,
    fixed_axes: bool = True,
    equal_aspect: bool = True,
    frame_duration: int = 500,
    transition_duration: int = 500,
    samples: int = 0,
) -> Figure:
    """
    Create an interactive Plotly scatter plot with animation over time.

    Parameters:
    - combined_df: pd.DataFrame
        DataFrame containing at least 'id', 'time', and two numerical columns
        other than 'id'/'time'. The first two such columns (in column order)
        are plotted as x and y. An optional 'label' column is shown in the
        hover tooltip.
    - title: str
        Title of the plot.
    - color_column: Optional[str]
        Column name for color encoding. If None or not present, no color
        encoding is applied.
    - fixed_axes: bool
        If True, axes ranges are fixed across all frames for consistency.
    - equal_aspect: bool
        If True, the plot will have an equal aspect ratio.
    - frame_duration: int
        Duration of each animation frame in milliseconds (Play button).
    - transition_duration: int
        Duration of the transition between frames in milliseconds (Play button).
    - samples: int (optional)
        If > 0, randomly keep at most this many ids (with all their time
        steps) for faster rendering. Uses numpy's global random state.

    Returns:
    - fig: plotly.graph_objs._figure.Figure
        The Plotly figure object.

    Raises:
    - ValueError: If fewer than two numerical plotting columns are available.
    """
    # Step 1: numerical columns, excluding the identifier and time axis.
    numeric_columns = [
        col
        for col in combined_df.select_dtypes(
            include=["float", "int", "bool"]
        ).columns.tolist()
        if col not in ("id", "time")
    ]
    if len(numeric_columns) < 2:
        raise ValueError(
            "DataFrame must have at least two numerical columns for x and y axes."
        )

    # Step 2: the first two numerical columns become x and y.
    default_x, default_y = numeric_columns[0], numeric_columns[1]

    # Step 3: tooltip content. The id is always shown; the 'label' column
    # produced by create_embedding_dataframe is shown when present.
    hover_data = ["id"]
    if "label" in combined_df.columns:
        hover_data.append("label")

    # Sample whole trajectories (by id) if requested.
    if samples > 0:
        unique_ids = combined_df["id"].unique().tolist()
        n_samples = min(samples, len(unique_ids))
        sample_ids = np.random.choice(unique_ids, n_samples, replace=False)
        plot_df = combined_df[combined_df["id"].isin(sample_ids)]
    else:
        plot_df = combined_df

    # Step 4: fade points as the number of tracked ids grows (floor 0.1).
    n_ids = plot_df["id"].nunique()
    opacity = max(0.1, min(5000.0 / n_ids, 1)) if n_ids else 1.0

    # Step 5: animated scatter plot; frames are ordered by sorted time value.
    use_color = bool(color_column) and color_column in combined_df.columns
    fig = px.scatter(
        plot_df,
        x=default_x,
        y=default_y,
        animation_frame="time",
        animation_group="id",
        color=color_column if use_color else None,
        hover_data=hover_data,
        title=title,
        labels={default_x: "x", default_y: "y", "time": "Time"},
        category_orders={"time": sorted(combined_df["time"].unique())},
        opacity=opacity,
    )

    # Step 6: fix the axis ranges so they do not rescale between frames.
    if fixed_axes:
        fig.update_layout(
            xaxis=dict(range=[plot_df[default_x].min(), plot_df[default_x].max()]),
            yaxis=dict(range=[plot_df[default_y].min(), plot_df[default_y].max()]),
        )

    # Step 7: equal aspect ratio.
    if equal_aspect:
        fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Step 8: axis titles and figure size in a single layout update.
    layout_kwargs = dict(
        xaxis_title=default_x.replace("_", " ").title(),
        yaxis_title=default_y.replace("_", " ").title(),
        width=800,
        height=800,
    )
    if len(numeric_columns) > 2:
        # Wider top/bottom margins when extra numerical columns are available.
        layout_kwargs["margin"] = dict(t=100, b=150)
    fig.update_layout(**layout_kwargs)

    # Step 9: apply the requested animation speed. Plotly Express creates a
    # Play button (args[0] is None) and a Pause button (args[0] == [None]);
    # only Play is changed, so Pause keeps its zero durations.
    for updatemenu in fig.layout.updatemenus or ():
        for btn in updatemenu.buttons or ():
            args = btn.args
            if not args or len(args) < 2 or not isinstance(args[1], dict):
                continue
            if args[0] is not None:
                continue
            args[1].setdefault("frame", {})["duration"] = frame_duration
            args[1].setdefault("transition", {})["duration"] = transition_duration

    return fig


def generate_initial_frame(
    num_points: int, num_features: int, seed: int = 42, id_prefix: str = "Point"
) -> pd.DataFrame:
    """
    Generate an initial frame with random points and unique IDs.

    Note: this reseeds numpy's global random state with `seed`.

    Parameters:
    - num_points: int
        Number of data points.
    - num_features: int
        Number of features per data point.
    - seed: int
        Random seed for reproducibility.
    - id_prefix: str
        Currently unused; IDs are plain integers 0..num_points-1. Kept for
        backward compatibility of the signature.

    Returns:
    - df: pd.DataFrame
        Columns 'feature_0'..'feature_{num_features-1}' with standard-normal
        values, followed by an integer 'id' column and a 'time' column of 0.
    """
    np.random.seed(seed)
    data = np.random.randn(num_points, num_features)
    df = pd.DataFrame(data, columns=[f"feature_{j}" for j in range(num_features)])
    df["id"] = np.arange(num_points, dtype=int)
    df["time"] = 0
    return df


def generate_jittered_snapshots(
    initial_df: pd.DataFrame,
    num_timesteps: int,
    jitter_scale: float = 0.1,
    seed: int = 42,
) -> List[pd.DataFrame]:
    """
    Generate one jittered snapshot per timestep, with random point add/remove.

    At every step each point's features receive Gaussian noise; then, with
    equal probability, one new point is added or one random point is removed
    (removal is skipped if only one point remains). Note: this reseeds numpy's
    global random state with `seed`.

    Parameters:
    - initial_df: pd.DataFrame
        The initial DataFrame to apply jitter. Must contain 'id' and 'time';
        every other column is treated as a feature.
    - num_timesteps: int
        Number of timesteps (one snapshot produced per timestep).
    - jitter_scale: float
        Standard deviation of the Gaussian noise added for jitter.
    - seed: int
        Random seed for reproducibility.

    Returns:
    - snapshots: List[pd.DataFrame]
        List of jittered DataFrames (features, 'id', 'time' with time = 1, 2,
        ...) with dynamic point introduction/removal. Added points always get
        a fresh id that has never been used before.
    """
    np.random.seed(seed)
    feature_cols = [col for col in initial_df.columns if col not in ("id", "time")]

    snapshots = []
    current_df = initial_df.copy()
    # Monotonic id counter: using "current max + 1" would hand a removed
    # point's id to a new point, which breaks per-id identity across frames.
    next_id = initial_df["id"].max() + 1

    for step in range(1, num_timesteps + 1):
        jitter = np.random.normal(
            loc=0.0,
            scale=jitter_scale,
            size=(current_df.shape[0], len(feature_cols)),
        )
        jittered_df = current_df[feature_cols] + jitter
        jittered_df["id"] = current_df["id"]

        # "none" has probability 0, so every step adds or removes a point.
        action = np.random.choice(["add", "remove", "none"], p=[0.5, 0.5, 0])
        if action == "add":
            new_point = np.random.randn(1, len(feature_cols))
            new_df = pd.DataFrame(new_point, columns=feature_cols)
            new_df["id"] = next_id
            next_id += 1
            jittered_df = pd.concat([jittered_df, new_df], ignore_index=True)
        elif action == "remove" and len(jittered_df) > 1:
            remove_idx = np.random.choice(jittered_df.index)
            jittered_df = jittered_df.drop(index=remove_idx).reset_index(drop=True)

        jittered_df["time"] = step  # time starts at 1
        snapshots.append(jittered_df)
        current_df = jittered_df.copy()

    return snapshots