Improve training signal and add honest eval metrics

- prepare_training_data: bag_size 5→8, street signs fill slots first so
  every sample contains the most geographically discriminative texts
- train: HuberLoss replaces MSE (robust to outlier intersections),
  ReduceLROnPlateau scheduler added, split logic extracted to data_utils
- eval: reproduce train/val split to report honest per-bag and
  per-intersection-aggregated metrics separately for train and val sets
- data_utils: shared split_indices() so train and eval use identical splits

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Michael Pilosov 2026-05-25 22:41:14 +00:00
parent b72cd1b917
commit e9d3d72499
4 changed files with 151 additions and 37 deletions

27
data_utils.py Normal file
View File

@ -0,0 +1,27 @@
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split
DEFAULT_TEST_SIZE = 0.2
DEFAULT_SEED = 1992
def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED):
"""Return (train_indices, val_indices) arrays.
Uses a group-aware split on the ``intersection`` column when present so
that every row from the same intersection lands in the same partition.
Falls back to a plain random split otherwise.
"""
indices = np.arange(len(data))
if "intersection" in data.columns:
splitter = GroupShuffleSplit(
n_splits=1, test_size=test_size, random_state=seed
)
train_idx, val_idx = next(
splitter.split(indices, groups=data["intersection"])
)
else:
train_idx, val_idx = train_test_split(
indices, test_size=test_size, random_state=seed
)
return train_idx, val_idx

72
eval.py
View File

@ -11,6 +11,7 @@ import pandas as pd # noqa: E402
import torch # noqa: E402 import torch # noqa: E402
from sentence_transformers import SentenceTransformer # noqa: E402 from sentence_transformers import SentenceTransformer # noqa: E402
from data_utils import split_indices # noqa: E402
from train import CoordinateRegressor, haversine_km # noqa: E402 from train import CoordinateRegressor, haversine_km # noqa: E402
matplotlib.use("Agg") matplotlib.use("Agg")
@ -34,6 +35,18 @@ def parse_args():
help="Device to use for evaluation. Defaults to `cuda`.", help="Device to use for evaluation. Defaults to `cuda`.",
) )
parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument(
"--seed",
type=int,
default=1992,
help="Must match the seed used during training for a consistent split.",
)
parser.add_argument(
"--test-size",
type=float,
default=0.2,
help="Must match the test-size used during training.",
)
return parser.parse_args() return parser.parse_args()
@ -218,7 +231,10 @@ def main():
coord_std, coord_std,
) = load_model(args.model_path, device) ) = load_model(args.model_path, device)
data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"]) data = pd.read_csv(args.data_file).dropna(
subset=["text", "latitude", "longitude"]
)
data = data.reset_index(drop=True)
texts = data["text"].astype(str).tolist() texts = data["text"].astype(str).tolist()
predicted = predict( predicted = predict(
trained_encoder, trained_encoder,
@ -248,6 +264,15 @@ def main():
results["initial_predicted_latitude"] = initial_predicted[:, 0] results["initial_predicted_latitude"] = initial_predicted[:, 0]
results["initial_predicted_longitude"] = initial_predicted[:, 1] results["initial_predicted_longitude"] = initial_predicted[:, 1]
results["error_km"] = errors results["error_km"] = errors
# Reproduce the same train/val split used during training.
train_indices, val_indices = split_indices(
data, test_size=args.test_size, seed=args.seed
)
val_mask = np.zeros(len(results), dtype=bool)
val_mask[val_indices] = True
results["split"] = np.where(val_mask, "val", "train")
results.to_csv(args.output_file, index=False) results.to_csv(args.output_file, index=False)
make_prediction_plot(results, args.plot_file) make_prediction_plot(results, args.plot_file)
make_predicted_vs_actual_plot(results, args.scatter_plot_file) make_predicted_vs_actual_plot(results, args.scatter_plot_file)
@ -255,9 +280,48 @@ def main():
print(f"Wrote predictions to {args.output_file}") print(f"Wrote predictions to {args.output_file}")
print(f"Wrote plot to {args.plot_file}") print(f"Wrote plot to {args.plot_file}")
print(f"Wrote scatter plot to {args.scatter_plot_file}") print(f"Wrote scatter plot to {args.scatter_plot_file}")
print(f"mean_error_km={errors.mean():.3f}")
print(f"median_error_km={np.median(errors):.3f}") def _fmt(errs):
print(f"p90_error_km={np.percentile(errors, 90):.3f}") return (
f"mean={np.mean(errs):.3f} "
f"median={np.median(errs):.3f} "
f"p90={np.percentile(errs, 90):.3f}"
)
def _intersection_agg_errors(df):
"""Average per-bag predictions to one coordinate per intersection."""
agg = (
df.groupby("intersection")
.agg(
pred_lat=("predicted_latitude", "mean"),
pred_lon=("predicted_longitude", "mean"),
true_lat=("latitude", "first"),
true_lon=("longitude", "first"),
)
.reset_index()
)
return haversine_km(
agg[["pred_lat", "pred_lon"]].to_numpy(dtype=np.float32),
agg[["true_lat", "true_lon"]].to_numpy(dtype=np.float32),
)
val_df = results[results["split"] == "val"]
train_df = results[results["split"] == "train"]
val_bag_errors = val_df["error_km"].to_numpy()
train_bag_errors = train_df["error_km"].to_numpy()
val_agg_errors = _intersection_agg_errors(val_df)
train_agg_errors = _intersection_agg_errors(train_df)
n_val = val_df["intersection"].nunique() if "intersection" in val_df.columns else "?"
n_train = train_df["intersection"].nunique() if "intersection" in train_df.columns else "?"
print()
print(f"[all bags ({len(errors):>6} rows)] {_fmt(errors)}")
print(f"[train bags({len(train_bag_errors):>6} rows)] {_fmt(train_bag_errors)}")
print(f"[val bags ({len(val_bag_errors):>6} rows)] {_fmt(val_bag_errors)}")
print(f"[train agg ({n_train:>3} intersections)] {_fmt(train_agg_errors)}")
print(f"[val agg ({n_val:>3} intersections)] {_fmt(val_agg_errors)} ← generalization")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,8 +1,10 @@
import argparse import argparse
import random
import pandas as pd import pandas as pd
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"] BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
OPTIONAL_COLUMNS = ["code_type"]
EXCLUDED_INTERSECTIONS = {"56th-pena"} EXCLUDED_INTERSECTIONS = {"56th-pena"}
@ -31,8 +33,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--bag-size", "--bag-size",
type=int, type=int,
default=5, default=8,
help="Number of sign texts to sample for each training row.", help=(
"Total number of sign texts per training row. Street signs fill "
"slots first; remaining slots are randomly sampled from all signs."
),
) )
parser.add_argument( parser.add_argument(
"--samples-per-intersection", "--samples-per-intersection",
@ -50,12 +55,13 @@ def parse_args():
def load_raw_data(path): def load_raw_data(path):
data = pd.read_csv(path) data = pd.read_csv(path)
missing = sorted(set(INPUT_COLUMNS) - set(data.columns)) missing = sorted(set(BASE_COLUMNS) - set(data.columns))
if missing: if missing:
raise ValueError(f"Missing required columns: {missing}") raise ValueError(f"Missing required columns: {missing}")
data = data[INPUT_COLUMNS].copy() keep = BASE_COLUMNS + [c for c in OPTIONAL_COLUMNS if c in data.columns]
data = data.dropna(subset=INPUT_COLUMNS) data = data[keep].copy()
data = data.dropna(subset=BASE_COLUMNS)
data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip() data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip()
data = data[data["text_on_sign_exact"] != ""] data = data[data["text_on_sign_exact"] != ""]
data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)] data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)]
@ -68,28 +74,48 @@ def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separato
if samples_per_intersection < 1: if samples_per_intersection < 1:
raise ValueError("--samples-per-intersection must be at least 1") raise ValueError("--samples-per-intersection must be at least 1")
has_code_type = "code_type" in data.columns
rows = [] rows = []
grouped = data.groupby("intersection", sort=True) grouped = data.groupby("intersection", sort=True)
for intersection, group in grouped: for intersection, group in grouped:
texts = group["text_on_sign_exact"].tolist()
latitude = group["latitude"].mean() latitude = group["latitude"].mean()
longitude = group["longitude"].mean() longitude = group["longitude"].mean()
all_texts = group["text_on_sign_exact"].tolist()
# Street signs are the most geographically discriminative; always
# include them first up to bag_size, then fill remaining slots randomly.
if has_code_type:
street_texts = group[group["code_type"] == "street_sign"][
"text_on_sign_exact"
].tolist()
else:
street_texts = []
for sample_id in range(samples_per_intersection): for sample_id in range(samples_per_intersection):
sampled = group.sample( rng_state = seed + len(rows)
n=bag_size,
replace=True, guaranteed = street_texts[:bag_size]
random_state=seed + len(rows), remaining = bag_size - len(guaranteed)
) if remaining > 0:
filler = group.sample(
n=remaining, replace=True, random_state=rng_state
)["text_on_sign_exact"].tolist()
else:
filler = []
bag = guaranteed + filler
# Shuffle so position in the sequence doesn't encode sign type.
random.Random(rng_state).shuffle(bag)
rows.append( rows.append(
{ {
"intersection": intersection, "intersection": intersection,
"sample_id": sample_id, "sample_id": sample_id,
"text": separator.join(sampled["text_on_sign_exact"].tolist()), "text": separator.join(bag),
"latitude": latitude, "latitude": latitude,
"longitude": longitude, "longitude": longitude,
"unique_sign_count": len(set(texts)), "unique_sign_count": len(set(all_texts)),
"raw_sign_count": len(texts), "raw_sign_count": len(all_texts),
} }
) )

View File

@ -8,7 +8,7 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
from sentence_transformers import LoggingHandler, SentenceTransformer from sentence_transformers import LoggingHandler, SentenceTransformer
from sklearn.model_selection import GroupShuffleSplit, train_test_split from data_utils import split_indices
from torch import nn from torch import nn
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
@ -91,9 +91,9 @@ def parse_args():
"epoch). Validation still runs every epoch." "epoch). Validation still runs every epoch."
), ),
) )
parser.add_argument("--learning-rate", type=float, default=2e-5) parser.add_argument("--learning-rate", type=float, default=1e-4)
parser.add_argument("--head-learning-rate", type=float, default=1e-3) parser.add_argument("--head-learning-rate", type=float, default=5e-2)
parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--weight-decay", type=float, default=0.001)
parser.add_argument("--test-size", type=float, default=0.2) parser.add_argument("--test-size", type=float, default=0.2)
parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--hidden-dim", type=int, default=256)
parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--dropout", type=float, default=0.1)
@ -433,18 +433,9 @@ def main():
coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32) coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates) normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
indices = np.arange(len(data)) train_indices, val_indices = split_indices(
if "intersection" in data.columns: data, test_size=args.test_size, seed=args.seed
splitter = GroupShuffleSplit( )
n_splits=1, test_size=args.test_size, random_state=args.seed
)
train_indices, val_indices = next(
splitter.split(indices, groups=data["intersection"])
)
else:
train_indices, val_indices = train_test_split(
indices, test_size=args.test_size, random_state=args.seed
)
train_dataset = SignCoordinateDataset( train_dataset = SignCoordinateDataset(
[texts[i] for i in train_indices], normalized_coordinates[train_indices] [texts[i] for i in train_indices], normalized_coordinates[train_indices]
@ -515,7 +506,10 @@ def main():
) )
optimizer = make_optimizer(encoder, head, args) optimizer = make_optimizer(encoder, head, args)
loss_fn = nn.MSELoss() loss_fn = nn.HuberLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7
)
best_val_loss = float("inf") best_val_loss = float("inf")
best_states = None best_states = None
pending_save = False pending_save = False
@ -552,9 +546,12 @@ def main():
coord_std, coord_std,
device, device,
) )
scheduler.step(val_loss)
current_lr = optimizer.param_groups[-1]["lr"]
print( print(
f"epoch={epoch} train_loss={train_loss:.6f} " f"epoch={epoch} train_loss={train_loss:.6f} "
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}" f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} "
f"lr={current_lr:.2e}"
) )
if val_loss < best_val_loss: if val_loss < best_val_loss: