From e9d3d724993b39fc40f3a4db723723c521921af5 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 25 May 2026 22:41:14 +0000 Subject: [PATCH] Improve training signal and add honest eval metrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - prepare_training_data: bag_size 5→8, street signs fill slots first so every sample contains the most geographically discriminative texts - train: HuberLoss replaces MSE (robust to outlier intersections), ReduceLROnPlateau scheduler added, split logic extracted to data_utils - eval: reproduce train/val split to report honest per-bag and per-intersection-aggregated metrics separately for train and val sets - data_utils: shared split_indices() so train and eval use identical splits Co-authored-by: Cursor --- data_utils.py | 27 +++++++++++++++ eval.py | 72 +++++++++++++++++++++++++++++++++++++--- prepare_training_data.py | 56 ++++++++++++++++++++++--------- train.py | 33 +++++++++--------- 4 files changed, 151 insertions(+), 37 deletions(-) create mode 100644 data_utils.py diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000..477ec74 --- /dev/null +++ b/data_utils.py @@ -0,0 +1,27 @@ +import numpy as np +from sklearn.model_selection import GroupShuffleSplit, train_test_split + +DEFAULT_TEST_SIZE = 0.2 +DEFAULT_SEED = 1992 + + +def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED): + """Return (train_indices, val_indices) arrays. + + Uses a group-aware split on the ``intersection`` column when present so + that every row from the same intersection lands in the same partition. + Falls back to a plain random split otherwise. + """ + indices = np.arange(len(data)) + if "intersection" in data.columns: + splitter = GroupShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + train_idx, val_idx = next( + splitter.split(indices, groups=data["intersection"]) + ) + else: + train_idx, val_idx = train_test_split( + indices, test_size=test_size, random_state=seed + ) + return train_idx, val_idx diff --git a/eval.py b/eval.py index 9d57c82..3d24d8c 100644 --- a/eval.py +++ b/eval.py @@ -11,6 +11,7 @@ import pandas as pd # noqa: E402 import torch # noqa: E402 from sentence_transformers import SentenceTransformer # noqa: E402 +from data_utils import split_indices # noqa: E402 from train import CoordinateRegressor, haversine_km # noqa: E402 matplotlib.use("Agg") @@ -34,6 +35,18 @@ def parse_args(): help="Device to use for evaluation. Defaults to `cuda`.", ) parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--seed", + type=int, + default=1992, + help="Must match the seed used during training for a consistent split.", + ) + parser.add_argument( + "--test-size", + type=float, + default=0.2, + help="Must match the test-size used during training.", + ) return parser.parse_args() @@ -218,7 +231,10 @@ def main(): coord_std, ) = load_model(args.model_path, device) - data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"]) + data = pd.read_csv(args.data_file).dropna( + subset=["text", "latitude", "longitude"] + ) + data = data.reset_index(drop=True) texts = data["text"].astype(str).tolist() predicted = predict( trained_encoder, @@ -248,6 +264,15 @@ def main(): results["initial_predicted_latitude"] = initial_predicted[:, 0] results["initial_predicted_longitude"] = initial_predicted[:, 1] results["error_km"] = errors + + # Reproduce the same train/val split used during training. + train_indices, val_indices = split_indices( + data, test_size=args.test_size, seed=args.seed + ) + val_mask = np.zeros(len(results), dtype=bool) + val_mask[val_indices] = True + results["split"] = np.where(val_mask, "val", "train") + results.to_csv(args.output_file, index=False) make_prediction_plot(results, args.plot_file) make_predicted_vs_actual_plot(results, args.scatter_plot_file) @@ -255,9 +280,48 @@ def main(): print(f"Wrote predictions to {args.output_file}") print(f"Wrote plot to {args.plot_file}") print(f"Wrote scatter plot to {args.scatter_plot_file}") - print(f"mean_error_km={errors.mean():.3f}") - print(f"median_error_km={np.median(errors):.3f}") - print(f"p90_error_km={np.percentile(errors, 90):.3f}") + + def _fmt(errs): + return ( + f"mean={np.mean(errs):.3f} " + f"median={np.median(errs):.3f} " + f"p90={np.percentile(errs, 90):.3f}" + ) + + def _intersection_agg_errors(df): + """Average per-bag predictions to one coordinate per intersection.""" + agg = ( + df.groupby("intersection") + .agg( + pred_lat=("predicted_latitude", "mean"), + pred_lon=("predicted_longitude", "mean"), + true_lat=("latitude", "first"), + true_lon=("longitude", "first"), + ) + .reset_index() + ) + return haversine_km( + agg[["pred_lat", "pred_lon"]].to_numpy(dtype=np.float32), + agg[["true_lat", "true_lon"]].to_numpy(dtype=np.float32), + ) + + val_df = results[results["split"] == "val"] + train_df = results[results["split"] == "train"] + + val_bag_errors = val_df["error_km"].to_numpy() + train_bag_errors = train_df["error_km"].to_numpy() + val_agg_errors = _intersection_agg_errors(val_df) + train_agg_errors = _intersection_agg_errors(train_df) + + n_val = val_df["intersection"].nunique() if "intersection" in val_df.columns else "?" + n_train = train_df["intersection"].nunique() if "intersection" in train_df.columns else "?" + + print() + print(f"[all bags ({len(errors):>6} rows)] {_fmt(errors)}") + print(f"[train bags({len(train_bag_errors):>6} rows)] {_fmt(train_bag_errors)}") + print(f"[val bags ({len(val_bag_errors):>6} rows)] {_fmt(val_bag_errors)}") + print(f"[train agg ({n_train:>3} intersections)] {_fmt(train_agg_errors)}") + print(f"[val agg ({n_val:>3} intersections)] {_fmt(val_agg_errors)} ← generalization") if __name__ == "__main__": diff --git a/prepare_training_data.py b/prepare_training_data.py index cdabc8c..849ef39 100644 --- a/prepare_training_data.py +++ b/prepare_training_data.py @@ -1,8 +1,10 @@ import argparse +import random import pandas as pd -INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"] +BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"] +OPTIONAL_COLUMNS = ["code_type"] EXCLUDED_INTERSECTIONS = {"56th-pena"} @@ -31,8 +33,11 @@ def parse_args(): parser.add_argument( "--bag-size", type=int, - default=5, - help="Number of sign texts to sample for each training row.", + default=8, + help=( + "Total number of sign texts per training row. Street signs fill " + "slots first; remaining slots are randomly sampled from all signs." + ), ) parser.add_argument( "--samples-per-intersection", @@ -50,12 +55,13 @@ def parse_args(): def load_raw_data(path): data = pd.read_csv(path) - missing = sorted(set(INPUT_COLUMNS) - set(data.columns)) + missing = sorted(set(BASE_COLUMNS) - set(data.columns)) if missing: raise ValueError(f"Missing required columns: {missing}") - data = data[INPUT_COLUMNS].copy() - data = data.dropna(subset=INPUT_COLUMNS) + keep = BASE_COLUMNS + [c for c in OPTIONAL_COLUMNS if c in data.columns] + data = data[keep].copy() + data = data.dropna(subset=BASE_COLUMNS) data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip() data = data[data["text_on_sign_exact"] != ""] data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)] @@ -68,28 +74,48 @@ def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separato if samples_per_intersection < 1: raise ValueError("--samples-per-intersection must be at least 1") + has_code_type = "code_type" in data.columns rows = [] grouped = data.groupby("intersection", sort=True) for intersection, group in grouped: - texts = group["text_on_sign_exact"].tolist() latitude = group["latitude"].mean() longitude = group["longitude"].mean() + all_texts = group["text_on_sign_exact"].tolist() + + # Street signs are the most geographically discriminative; always + # include them first up to bag_size, then fill remaining slots randomly. + if has_code_type: + street_texts = group[group["code_type"] == "street_sign"][ + "text_on_sign_exact" + ].tolist() + else: + street_texts = [] for sample_id in range(samples_per_intersection): - sampled = group.sample( - n=bag_size, - replace=True, - random_state=seed + len(rows), - ) + rng_state = seed + len(rows) + + guaranteed = street_texts[:bag_size] + remaining = bag_size - len(guaranteed) + if remaining > 0: + filler = group.sample( + n=remaining, replace=True, random_state=rng_state + )["text_on_sign_exact"].tolist() + else: + filler = [] + + bag = guaranteed + filler + # Shuffle so position in the sequence doesn't encode sign type. + random.Random(rng_state).shuffle(bag) + rows.append( { "intersection": intersection, "sample_id": sample_id, - "text": separator.join(sampled["text_on_sign_exact"].tolist()), + "text": separator.join(bag), "latitude": latitude, "longitude": longitude, - "unique_sign_count": len(set(texts)), - "raw_sign_count": len(texts), + "unique_sign_count": len(set(all_texts)), + "raw_sign_count": len(all_texts), } ) diff --git a/train.py b/train.py index 4adb10d..df10e8e 100644 --- a/train.py +++ b/train.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd import torch from sentence_transformers import LoggingHandler, SentenceTransformer -from sklearn.model_selection import GroupShuffleSplit, train_test_split +from data_utils import split_indices from torch import nn from torch.utils.data import DataLoader, Dataset @@ -91,9 +91,9 @@ def parse_args(): "epoch). Validation still runs every epoch." ), ) - parser.add_argument("--learning-rate", type=float, default=2e-5) - parser.add_argument("--head-learning-rate", type=float, default=1e-3) - parser.add_argument("--weight-decay", type=float, default=0.01) + parser.add_argument("--learning-rate", type=float, default=1e-4) + parser.add_argument("--head-learning-rate", type=float, default=5e-2) + parser.add_argument("--weight-decay", type=float, default=0.001) parser.add_argument("--test-size", type=float, default=0.2) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--dropout", type=float, default=0.1) @@ -433,18 +433,9 @@ def main(): coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32) normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates) - indices = np.arange(len(data)) - if "intersection" in data.columns: - splitter = GroupShuffleSplit( - n_splits=1, test_size=args.test_size, random_state=args.seed - ) - train_indices, val_indices = next( - splitter.split(indices, groups=data["intersection"]) - ) - else: - train_indices, val_indices = train_test_split( - indices, test_size=args.test_size, random_state=args.seed - ) + train_indices, val_indices = split_indices( + data, test_size=args.test_size, seed=args.seed + ) train_dataset = SignCoordinateDataset( [texts[i] for i in train_indices], normalized_coordinates[train_indices] @@ -515,7 +506,10 @@ def main(): ) optimizer = make_optimizer(encoder, head, args) - loss_fn = nn.MSELoss() + loss_fn = nn.HuberLoss() + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7 + ) best_val_loss = float("inf") best_states = None pending_save = False @@ -552,9 +546,12 @@ def main(): coord_std, device, ) + scheduler.step(val_loss) + current_lr = optimizer.param_groups[-1]["lr"] print( f"epoch={epoch} train_loss={train_loss:.6f} " - f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}" + f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} " + f"lr={current_lr:.2e}" ) if val_loss < best_val_loss: