citybert/prepare_training_data.py
Michael Pilosov e9d3d72499 Improve training signal and add honest eval metrics
- prepare_training_data: bag_size 5→8, street signs fill slots first so
  every sample contains the most geographically discriminative texts
- train: HuberLoss replaces MSE (robust to outlier intersections),
  ReduceLROnPlateau scheduler added, split logic extracted to data_utils
- eval: reproduce train/val split to report honest per-bag and
  per-intersection-aggregated metrics separately for train and val sets
- data_utils: shared split_indices() so train and eval use identical splits

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 22:41:14 +00:00

144 lines
4.4 KiB
Python

import argparse
import random
import pandas as pd
# Columns every raw CSV must provide: load_raw_data raises ValueError if any
# is absent and drops rows that have a missing value in any of them.
BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
# Columns that are kept when present in the raw CSV but are not required.
OPTIONAL_COLUMNS = ["code_type"]
# Intersections filtered out of the raw data before bootstrapping.
EXCLUDED_INTERSECTIONS = {"56th-pena"}
def parse_args(argv=None):
    """Parse the command-line options of the bootstrapping script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing an explicit list allows
            callers and tests to drive the parser without touching
            process-wide state.

    Returns:
        argparse.Namespace with the attributes ``input_file``,
        ``output_file``, ``seed``, ``bag_size``,
        ``samples_per_intersection`` and ``separator``.
    """
    parser = argparse.ArgumentParser(
        description="Create bootstrapped sign-text bags for coordinate training."
    )
    parser.add_argument(
        "-i",
        "--input-file",
        default="training_data_raw.csv",
        help="Raw pandas-exported CSV.",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        default="training.csv",
        help="Prepared training CSV.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1992,
        help="Random seed for deterministic bootstrapping.",
    )
    parser.add_argument(
        "--bag-size",
        type=int,
        default=8,
        help=(
            "Total number of sign texts per training row. Street signs fill "
            "slots first; remaining slots are randomly sampled from all signs."
        ),
    )
    parser.add_argument(
        "--samples-per-intersection",
        type=int,
        default=100,
        help="Bootstrap samples to create for each intersection.",
    )
    parser.add_argument(
        "--separator",
        default=" | ",
        help="Separator used when joining sampled sign texts.",
    )
    return parser.parse_args(argv)
def load_raw_data(path):
    """Read the raw sign CSV and return only the rows usable for training.

    The result keeps the BASE_COLUMNS plus any OPTIONAL_COLUMNS present in
    the file. Rows are dropped when a base column is missing, when the sign
    text is empty after stripping whitespace, or when the intersection is
    listed in EXCLUDED_INTERSECTIONS.

    Args:
        path: Location of the CSV, passed straight to ``pd.read_csv``.

    Returns:
        pandas.DataFrame of the cleaned rows (original row index retained).

    Raises:
        ValueError: If the CSV lacks any of the BASE_COLUMNS.
    """
    frame = pd.read_csv(path)

    missing = sorted(set(BASE_COLUMNS).difference(frame.columns))
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    selected = list(BASE_COLUMNS)
    selected.extend(name for name in OPTIONAL_COLUMNS if name in frame.columns)

    cleaned = frame[selected].copy().dropna(subset=BASE_COLUMNS)
    # Normalise the text before testing for emptiness so whitespace-only
    # entries are discarded as well.
    cleaned["text_on_sign_exact"] = (
        cleaned["text_on_sign_exact"].astype(str).str.strip()
    )

    has_text = cleaned["text_on_sign_exact"] != ""
    is_excluded = cleaned["intersection"].isin(EXCLUDED_INTERSECTIONS)
    return cleaned[has_text & ~is_excluded]
def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separator):
    """Build bootstrapped sign-text bags, one output row per sample.

    For each intersection (processed in sorted order) this creates
    ``samples_per_intersection`` rows. Every row joins ``bag_size`` sign
    texts with ``separator``: texts whose ``code_type`` is ``"street_sign"``
    fill the bag first, and any remaining slots are drawn with replacement
    from all of the intersection's texts. The bag is shuffled before joining.

    Args:
        data: DataFrame with ``intersection``, ``text_on_sign_exact``,
            ``latitude`` and ``longitude`` columns, and optionally
            ``code_type``. Without ``code_type`` no slots are reserved.
        bag_size: Number of sign texts per output row; must be >= 1.
        samples_per_intersection: Rows to emit per intersection; must be >= 1.
        seed: Base integer seed. Row ``k`` of the output (counted across all
            intersections) uses ``seed + k`` for both sampling and shuffling.
        separator: String placed between the texts of a bag.

    Returns:
        pandas.DataFrame with the columns ``intersection``, ``sample_id``,
        ``text``, ``latitude``, ``longitude``, ``unique_sign_count`` and
        ``raw_sign_count``. ``latitude``/``longitude`` are the per-intersection
        means of the input coordinates. The columns are present even when
        ``data`` has no rows.

    Raises:
        ValueError: If ``bag_size`` or ``samples_per_intersection`` is < 1.
    """
    if bag_size < 1:
        raise ValueError("--bag-size must be at least 1")
    if samples_per_intersection < 1:
        raise ValueError("--samples-per-intersection must be at least 1")

    # Explicit schema so that an empty input still yields a frame (and thus a
    # CSV header) with the expected columns instead of a column-less frame.
    columns = [
        "intersection",
        "sample_id",
        "text",
        "latitude",
        "longitude",
        "unique_sign_count",
        "raw_sign_count",
    ]

    has_code_type = "code_type" in data.columns
    rows = []
    for intersection, group in data.groupby("intersection", sort=True):
        texts = group["text_on_sign_exact"]
        all_texts = texts.tolist()
        latitude = group["latitude"].mean()
        longitude = group["longitude"].mean()

        # Street signs are the most geographically discriminative; always
        # include them first up to bag_size, then fill remaining slots randomly.
        if has_code_type:
            street_texts = texts[group["code_type"] == "street_sign"].tolist()
        else:
            street_texts = []

        # Everything below is identical for each sample of this intersection,
        # so compute it once rather than on every loop iteration.
        # NOTE(review): the slice always keeps the same leading street-sign
        # rows (duplicates included); when an intersection has at least
        # bag_size of them, all of its samples contain the same texts and
        # differ only in order.
        guaranteed = street_texts[:bag_size]
        remaining = bag_size - len(guaranteed)
        unique_sign_count = len(set(all_texts))
        raw_sign_count = len(all_texts)

        for sample_id in range(samples_per_intersection):
            # A distinct, reproducible seed per output row.
            rng_state = seed + len(rows)
            if remaining > 0:
                # Sampling the text Series picks the same positions as
                # sampling the whole group with this random_state, without
                # materialising full rows.
                filler = texts.sample(
                    n=remaining, replace=True, random_state=rng_state
                ).tolist()
            else:
                filler = []
            bag = guaranteed + filler
            # Shuffle so position in the sequence doesn't encode sign type.
            random.Random(rng_state).shuffle(bag)
            rows.append(
                {
                    "intersection": intersection,
                    "sample_id": sample_id,
                    "text": separator.join(bag),
                    "latitude": latitude,
                    "longitude": longitude,
                    "unique_sign_count": unique_sign_count,
                    "raw_sign_count": raw_sign_count,
                }
            )
    return pd.DataFrame(rows, columns=columns)
def main():
    """Run the CLI: load the raw signs, bootstrap bags, and write the CSV.

    Reads options from the command line, writes the prepared rows to the
    requested output file without the index, and prints a one-line summary
    of how many rows and intersections were processed.
    """
    args = parse_args()
    raw_data = load_raw_data(args.input_file)

    bootstrap_options = {
        "bag_size": args.bag_size,
        "samples_per_intersection": args.samples_per_intersection,
        "seed": args.seed,
        "separator": args.separator,
    }
    training_data = make_bootstrap_data(raw_data, **bootstrap_options)
    training_data.to_csv(args.output_file, index=False)

    summary = (
        f"Wrote {len(training_data):,} rows from "
        f"{raw_data['intersection'].nunique():,} intersections to {args.output_file}"
    )
    print(summary)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()