citybert/prepare_training_data.py
2026-05-25 15:16:19 -06:00

116 lines
3.3 KiB
Python

import argparse
import pandas as pd
# Columns the raw CSV must provide. load_raw_data() raises if any is absent,
# keeps only these columns (in this order), and drops rows with a missing
# value in any of them; make_bootstrap_data() reads all four.
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``input_file``,
        ``output_file``, ``seed``, ``bag_size``,
        ``samples_per_intersection`` and ``separator``.

    Range checks on ``bag_size`` / ``samples_per_intersection`` are not
    done here; ``make_bootstrap_data`` performs them.
    """
    # (flags, add_argument keyword options), in the order shown by --help.
    option_specs = (
        (
            ("-i", "--input-file"),
            {"default": "training_data_raw.csv", "help": "Raw pandas-exported CSV."},
        ),
        (
            ("-o", "--output-file"),
            {"default": "training.csv", "help": "Prepared training CSV."},
        ),
        (
            ("--seed",),
            {
                "type": int,
                "default": 1992,
                "help": "Random seed for deterministic bootstrapping.",
            },
        ),
        (
            ("--bag-size",),
            {
                "type": int,
                "default": 5,
                "help": "Number of sign texts to sample for each training row.",
            },
        ),
        (
            ("--samples-per-intersection",),
            {
                "type": int,
                "default": 100,
                "help": "Bootstrap samples to create for each intersection.",
            },
        ),
        (
            ("--separator",),
            {
                "default": " | ",
                "help": "Separator used when joining sampled sign texts.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Create bootstrapped sign-text bags for coordinate training."
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def load_raw_data(path):
    """Read the raw CSV at *path* and keep only rows usable for training.

    Args:
        path: Anything ``pd.read_csv`` accepts (file path or buffer).

    Returns:
        DataFrame restricted to ``INPUT_COLUMNS`` (in that order). Rows with
        a missing value in any of those columns are dropped,
        ``text_on_sign_exact`` is converted to ``str`` and stripped, and rows
        whose stripped text is empty are removed. The original index is kept
        (not reset).

    Raises:
        ValueError: If any of ``INPUT_COLUMNS`` is absent from the file.
    """
    frame = pd.read_csv(path)

    absent = sorted(set(INPUT_COLUMNS).difference(frame.columns))
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    frame = frame.loc[:, INPUT_COLUMNS].dropna(subset=INPUT_COLUMNS)
    cleaned_text = frame["text_on_sign_exact"].astype(str).str.strip()
    frame = frame.assign(text_on_sign_exact=cleaned_text)

    # Whitespace-only sign text becomes "" after stripping; discard it.
    has_text = frame["text_on_sign_exact"] != ""
    return frame.loc[has_text]
def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separator):
    """Build bootstrapped bags of sign texts, one batch per intersection.

    Intersections are processed in sorted order. Each intersection
    produces ``samples_per_intersection`` rows. Each row joins ``bag_size``
    texts drawn *with replacement* from that intersection's
    ``text_on_sign_exact`` values. Its target is the mean latitude and
    longitude of the intersection's input rows.

    Args:
        data: DataFrame with ``intersection``, ``text_on_sign_exact``,
            ``latitude`` and ``longitude`` columns.
        bag_size: Number of texts sampled into each bag; must be >= 1.
        samples_per_intersection: Bags produced per intersection; must be >= 1.
        seed: Base integer seed. Output row ``k``, counted from 0 across all
            intersections, is drawn with ``random_state=seed + k``.
        separator: String placed between the sampled texts.

    Returns:
        DataFrame with columns ``intersection``, ``sample_id``, ``text``,
        ``latitude``, ``longitude``, ``unique_sign_count`` and
        ``raw_sign_count``. The columns are present even when ``data`` has
        no rows, so the written CSV always has a header.

    Raises:
        ValueError: If ``bag_size`` or ``samples_per_intersection`` is < 1.
    """
    if bag_size < 1:
        raise ValueError("--bag-size must be at least 1")
    if samples_per_intersection < 1:
        raise ValueError("--samples-per-intersection must be at least 1")

    output_columns = [
        "intersection",
        "sample_id",
        "text",
        "latitude",
        "longitude",
        "unique_sign_count",
        "raw_sign_count",
    ]
    rows = []
    for intersection, group in data.groupby("intersection", sort=True):
        # Everything below is constant for the intersection, so compute it
        # once rather than once per bootstrap sample.
        texts = group["text_on_sign_exact"]
        text_list = texts.tolist()
        latitude = group["latitude"].mean()
        longitude = group["longitude"].mean()
        unique_sign_count = len(set(text_list))
        raw_sign_count = len(text_list)

        for sample_id in range(samples_per_intersection):
            # Sampling the text Series draws the same positions as sampling
            # the whole group for a given random_state, without copying the
            # other columns.
            # NOTE: the per-row seed (seed + global row index) is kept so the
            # output for a given --seed matches earlier runs. As a result,
            # --seed N+1 uses the same seeds as --seed N shifted by one row,
            # so nearby seeds give heavily overlapping draws.
            sampled = texts.sample(
                n=bag_size,
                replace=True,
                random_state=seed + len(rows),
            )
            rows.append(
                {
                    "intersection": intersection,
                    "sample_id": sample_id,
                    "text": separator.join(sampled.tolist()),
                    "latitude": latitude,
                    "longitude": longitude,
                    "unique_sign_count": unique_sign_count,
                    "raw_sign_count": raw_sign_count,
                }
            )
    # Explicit columns keep the schema (and CSV header) even with no rows.
    return pd.DataFrame(rows, columns=output_columns)
def main():
    """Command-line entry point.

    Parses the CLI options, loads and cleans the raw CSV, builds the
    bootstrapped training rows, writes them to ``--output-file`` without
    the index, and prints a one-line summary.
    """
    cli_args = parse_args()
    raw = load_raw_data(cli_args.input_file)

    prepared = make_bootstrap_data(
        raw,
        bag_size=cli_args.bag_size,
        samples_per_intersection=cli_args.samples_per_intersection,
        seed=cli_args.seed,
        separator=cli_args.separator,
    )
    prepared.to_csv(cli_args.output_file, index=False)

    n_rows = len(prepared)
    n_intersections = raw["intersection"].nunique()
    print(
        f"Wrote {n_rows:,} rows from "
        f"{n_intersections:,} intersections to {cli_args.output_file}"
    )


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()