From c12d2cda6c043a3f40e23dcbf770ba054ae16f92 Mon Sep 17 00:00:00 2001
From: Michael Pilosov
Date: Wed, 22 Apr 2026 17:04:50 -0600
Subject: [PATCH] flow: hash user-supplied generator_kwargs, not the merged dict
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The flow previously merged _DEFAULT_GENERATOR_KWARGS={random_state:0}
and n_samples=num_points into generator_kwargs BEFORE hashing. Prefect
only records the user-supplied form, so the web app's synth_output_paths
disagreed with the flow's output name — a plain swiss_roll run showed
'embedding: n/a' in the runs list despite completing, because the web
app looked for the hash that excluded those defaults.

Now we keep the user-supplied generator_kwargs around for hashing +
metadata, and use the merged dict only for the actual generator call.
n_samples is already captured in the stem as 'N', and random_state=0 is
a flow constant — neither belongs in the semantic identity.
---
 flows/embedding_flow.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/flows/embedding_flow.py b/flows/embedding_flow.py
index 9ebe7a2..3d638ec 100644
--- a/flows/embedding_flow.py
+++ b/flows/embedding_flow.py
@@ -296,9 +296,16 @@ def embedding_flow(
     reference_speedup: float = 10.0,
     samples: int = 10_000,
 ):
+    # Preserve the user-supplied generator_kwargs for hashing / metadata —
+    # the merged dict (with random_state defaults + n_samples) goes to the
+    # generator itself but those aren't part of the run's semantic identity
+    # (random_state=0 is a flow constant; n_samples is captured as `N` in
+    # the stem). If the merged dict were hashed, the web app would disagree
+    # with the flow because Prefect only records the user-supplied form.
+    user_generator_kwargs = dict(generator_kwargs or {})
     generator_kwargs = {
         **_DEFAULT_GENERATOR_KWARGS,
-        **(generator_kwargs or {}),
+        **user_generator_kwargs,
         "n_samples": num_points,
     }
     embed_columns = (
@@ -311,7 +318,7 @@
     output_ref: str = (
         f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
     )
-    _args_tag = _run_args_hash(embed_args, generator_kwargs)
+    _args_tag = _run_args_hash(embed_args, user_generator_kwargs)
     output_embed: str = (
         f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
     )
@@ -405,7 +412,7 @@
         "jitter_scale": jitter_scale,
         "seed": seed,
         "generator_path": generator_path,
-        "generator_kwargs": generator_kwargs or {},
+        "generator_kwargs": user_generator_kwargs,
         "embedder": embedder,
         "embed_args": merged_embed_args,
     },
@@ -428,7 +435,7 @@
     frames = parse_plotly_run(emb_path_result)
     # Persist generator_kwargs so the server's label enrichment can
     # regenerate the correct dataset variant (swiss_roll vs hole).
-    frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {}
+    frames.setdefault("meta", {})["generator_kwargs"] = user_generator_kwargs
     Path(output_frames).write_text(
         json.dumps(frames, separators=(",", ":")), encoding="utf-8"
     )