citybert/data_utils.py
Michael Pilosov e9d3d72499 Improve training signal and add honest eval metrics
- prepare_training_data: bag_size 5→8, street signs fill slots first so
  every sample contains the most geographically discriminative texts
- train: HuberLoss replaces MSE (robust to outlier intersections),
  ReduceLROnPlateau scheduler added, split logic extracted to data_utils
- eval: reproduce train/val split to report honest per-bag and
  per-intersection-aggregated metrics separately for train and val sets
- data_utils: shared split_indices() so train and eval use identical splits

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 22:41:14 +00:00
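The `HuberLoss` change in train can be illustrated numerically. The NumPy sketch below compares MSE with the Huber loss on a toy batch containing one outlier. It follows the definition used by PyTorch's `torch.nn.HuberLoss` with its default `delta=1.0` and mean reduction. The residual values are made up, and the `delta` actually used in training is not shown in this file.

```python
import numpy as np


def mse(residuals):
    """Mean squared error over residuals r = prediction - target."""
    return float(np.mean(residuals ** 2))


def huber(residuals, delta=1.0):
    """Mean Huber loss: 0.5 * r**2 for |r| <= delta, else delta * (|r| - 0.5 * delta)."""
    abs_r = np.abs(residuals)
    quadratic = 0.5 * residuals ** 2
    linear = delta * (abs_r - 0.5 * delta)
    return float(np.mean(np.where(abs_r <= delta, quadratic, linear)))


def mse_grad(residuals):
    """Per-sample gradient of r**2 with respect to r (grows linearly with r)."""
    return 2.0 * residuals


def huber_grad(residuals, delta=1.0):
    """Per-sample gradient of the Huber loss: r, clipped to [-delta, delta]."""
    return np.clip(residuals, -delta, delta)


# Hypothetical residuals: four well-fit samples and one outlier intersection.
residuals = np.array([0.5, -0.5, 0.5, -0.5, 10.0])

print(mse(residuals))    # 20.2
print(huber(residuals))  # 2.0

# Under MSE the outlier's gradient is 20x an inlier's (20.0 vs 1.0);
# under Huber it is only 2x (1.0 vs 0.5), so it cannot dominate the update.
print(mse_grad(residuals)[-1] / mse_grad(residuals)[0])      # 20.0
print(huber_grad(residuals)[-1] / huber_grad(residuals)[0])  # 2.0
```

The loss value still grows with the outlier, but only linearly beyond `delta`. The gradient clipping is what keeps a single badly predicted intersection from dominating a training step.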

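import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split

DEFAULT_TEST_SIZE = 0.2
DEFAULT_SEED = 1992


def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED):
    """Return ``(train_indices, val_indices)`` as positional index arrays.

    Uses a group-aware split on the ``intersection`` column when present so
    that every row from the same intersection lands in the same partition.
    In that case ``test_size`` is the fraction of *intersections* held out
    (rounded up to a whole number of intersections), so the fraction of rows
    in the validation set is only approximately ``test_size``.
    Falls back to a plain random row-level split otherwise.

    Because ``seed`` is fixed, calling this with the same ``data`` (same rows
    in the same order) always returns the same partition, so training and
    evaluation can each reconstruct identical splits independently.
    """
    indices = np.arange(len(data))
    if "intersection" in data.columns:
        # Hold out whole intersections so no location leaks across splits.
        splitter = GroupShuffleSplit(
            n_splits=1, test_size=test_size, random_state=seed
        )
        train_idx, val_idx = next(
            splitter.split(indices, groups=data["intersection"])
        )
    else:
        # No grouping column: shuffle and split individual rows.
        train_idx, val_idx = train_test_split(
            indices, test_size=test_size, random_state=seed
        )
    return train_idx, val_idx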
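A quick sanity check for the group-aware branch is to confirm that no intersection appears in both partitions and that repeated calls reproduce the same split. The sketch below calls `GroupShuffleSplit` directly with the same settings as `split_indices` (`n_splits=1`, `test_size=0.2`, `random_state=1992`), so it runs without the `citybert` package installed. The toy DataFrame and the `assert_no_group_leakage` helper are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit


def assert_no_group_leakage(data, train_idx, val_idx, group_col="intersection"):
    """Raise AssertionError if any group value appears in both partitions."""
    train_groups = set(data[group_col].iloc[train_idx])
    val_groups = set(data[group_col].iloc[val_idx])
    overlap = train_groups & val_groups
    assert not overlap, f"groups in both splits: {sorted(overlap)[:5]}"


# Toy data: 10 hypothetical intersections with 5 bags (rows) each.
toy = pd.DataFrame({
    "intersection": np.repeat([f"int_{i}" for i in range(10)], 5),
    "target": np.arange(50, dtype=float),
})
indices = np.arange(len(toy))

# Same settings as the group branch of split_indices; with citybert
# installed, split_indices(toy) should return the same arrays.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=1992)
train_idx, val_idx = next(splitter.split(indices, groups=toy["intersection"]))
assert_no_group_leakage(toy, train_idx, val_idx)

# A fresh splitter with the same seed yields the same partition, which is
# what lets train and eval rebuild identical splits on their own.
train_again, val_again = next(
    GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=1992)
    .split(indices, groups=toy["intersection"])
)
assert np.array_equal(train_idx, train_again)
assert np.array_equal(val_idx, val_again)

print(len(train_idx), len(val_idx))  # 40 10
```

With 10 equally sized intersections and `test_size=0.2`, two whole intersections (10 rows) are held out. With unequal group sizes, the row-level validation fraction would drift away from 0.2.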