diff --git a/flows/embedding_flow.py b/flows/embedding_flow.py index 9ebe7a2..3d638ec 100644 --- a/flows/embedding_flow.py +++ b/flows/embedding_flow.py @@ -296,9 +296,16 @@ def embedding_flow( reference_speedup: float = 10.0, samples: int = 10_000, ): + # Preserve the user-supplied generator_kwargs for hashing / metadata — + # the merged dict (with random_state defaults + n_samples) goes to the + # generator itself but those aren't part of the run's semantic identity + # (random_state=0 is a flow constant; n_samples is captured as `N` in + # the stem). If the merged dict were hashed, the web app would disagree + # with the flow because Prefect only records the user-supplied form. + user_generator_kwargs = dict(generator_kwargs or {}) generator_kwargs = { **_DEFAULT_GENERATOR_KWARGS, - **(generator_kwargs or {}), + **user_generator_kwargs, "n_samples": num_points, } embed_columns = ( @@ -311,7 +318,7 @@ def embedding_flow( output_ref: str = ( f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html" ) - _args_tag = _run_args_hash(embed_args, generator_kwargs) + _args_tag = _run_args_hash(embed_args, user_generator_kwargs) output_embed: str = ( f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html" ) @@ -405,7 +412,7 @@ def embedding_flow( "jitter_scale": jitter_scale, "seed": seed, "generator_path": generator_path, - "generator_kwargs": generator_kwargs or {}, + "generator_kwargs": user_generator_kwargs, "embedder": embedder, "embed_args": merged_embed_args, }, @@ -428,7 +435,7 @@ def embedding_flow( frames = parse_plotly_run(emb_path_result) # Persist generator_kwargs so the server's label enrichment can # regenerate the correct dataset variant (swiss_roll vs hole). - frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {} + frames.setdefault("meta", {})["generator_kwargs"] = user_generator_kwargs Path(output_frames).write_text( json.dumps(frames, separators=(",", ":")), encoding="utf-8" )