flow: hash user-supplied generator_kwargs, not the merged dict

The flow previously merged _DEFAULT_GENERATOR_KWARGS={random_state:0} and
n_samples=num_points into generator_kwargs BEFORE hashing. Prefect only
records the user-supplied form, so the web app's synth_output_paths
disagreed with the flow's output name — a plain swiss_roll run showed
'embedding: n/a' in the runs list despite completing, because the web
looked for the hash that excluded those defaults.

Now we keep the user-supplied generator_kwargs around for hashing +
metadata, and use the merged dict only for the actual generator call.
n_samples is already captured in the stem as 'N<n>', and random_state=0
is a flow constant — neither belongs in the semantic identity.
This commit is contained in:
Michael Pilosov 2026-04-22 17:04:50 -06:00
parent bdbebaa7e8
commit c12d2cda6c

View File

@ -296,9 +296,16 @@ def embedding_flow(
reference_speedup: float = 10.0, reference_speedup: float = 10.0,
samples: int = 10_000, samples: int = 10_000,
): ):
# Preserve the user-supplied generator_kwargs for hashing / metadata —
# the merged dict (with random_state defaults + n_samples) goes to the
# generator itself but those aren't part of the run's semantic identity
# (random_state=0 is a flow constant; n_samples is captured as `N` in
# the stem). If the merged dict were hashed, the web app would disagree
# with the flow because Prefect only records the user-supplied form.
user_generator_kwargs = dict(generator_kwargs or {})
generator_kwargs = { generator_kwargs = {
**_DEFAULT_GENERATOR_KWARGS, **_DEFAULT_GENERATOR_KWARGS,
**(generator_kwargs or {}), **user_generator_kwargs,
"n_samples": num_points, "n_samples": num_points,
} }
embed_columns = ( embed_columns = (
@ -311,7 +318,7 @@ def embedding_flow(
output_ref: str = ( output_ref: str = (
f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html" f"{output_dir.strip('/')}/{_generator}_Reference_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}.html"
) )
_args_tag = _run_args_hash(embed_args, generator_kwargs) _args_tag = _run_args_hash(embed_args, user_generator_kwargs)
output_embed: str = ( output_embed: str = (
f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html" f"{output_dir.strip('/')}/{_generator}_{embedder.split('.')[-1]}_N{num_points}_T{num_timesteps}_J{jitter_scale}_s{seed}_{_args_tag}.html"
) )
@ -405,7 +412,7 @@ def embedding_flow(
"jitter_scale": jitter_scale, "jitter_scale": jitter_scale,
"seed": seed, "seed": seed,
"generator_path": generator_path, "generator_path": generator_path,
"generator_kwargs": generator_kwargs or {}, "generator_kwargs": user_generator_kwargs,
"embedder": embedder, "embedder": embedder,
"embed_args": merged_embed_args, "embed_args": merged_embed_args,
}, },
@ -428,7 +435,7 @@ def embedding_flow(
frames = parse_plotly_run(emb_path_result) frames = parse_plotly_run(emb_path_result)
# Persist generator_kwargs so the server's label enrichment can # Persist generator_kwargs so the server's label enrichment can
# regenerate the correct dataset variant (swiss_roll vs hole). # regenerate the correct dataset variant (swiss_roll vs hole).
frames.setdefault("meta", {})["generator_kwargs"] = generator_kwargs or {} frames.setdefault("meta", {})["generator_kwargs"] = user_generator_kwargs
Path(output_frames).write_text( Path(output_frames).write_text(
json.dumps(frames, separators=(",", ":")), encoding="utf-8" json.dumps(frames, separators=(",", ":")), encoding="utf-8"
) )