116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
import argparse
|
|
|
|
import pandas as pd
|
|
|
|
# Columns the raw CSV must contain; the loader rejects files missing any of
# them and keeps only these columns.
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the bootstrapping script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected. Passing a list makes the parser usable from other
            code without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_file``,
        ``output_file``, ``seed``, ``bag_size``, ``samples_per_intersection``
        and ``separator``.
    """
    parser = argparse.ArgumentParser(
        description="Create bootstrapped sign-text bags for coordinate training."
    )
    parser.add_argument(
        "-i",
        "--input-file",
        default="training_data_raw.csv",
        help="Raw pandas-exported CSV.",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        default="training.csv",
        help="Prepared training CSV.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1992,
        help="Random seed for deterministic bootstrapping.",
    )
    parser.add_argument(
        "--bag-size",
        type=int,
        default=5,
        help="Number of sign texts to sample for each training row.",
    )
    parser.add_argument(
        "--samples-per-intersection",
        type=int,
        default=100,
        help="Bootstrap samples to create for each intersection.",
    )
    parser.add_argument(
        "--separator",
        default=" | ",
        help="Separator used when joining sampled sign texts.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_raw_data(path):
    """Read the raw CSV and return only the usable sign rows.

    The file must contain every column named in ``INPUT_COLUMNS``. Only those
    columns are kept; rows with a missing value in any of them are dropped,
    sign text is converted to ``str`` and stripped of surrounding whitespace,
    and rows whose text is empty after stripping are removed.

    Args:
        path: Location of the CSV, passed straight to ``pandas.read_csv``.

    Returns:
        A DataFrame holding the ``INPUT_COLUMNS`` columns of the surviving
        rows.

    Raises:
        ValueError: If any required column is absent from the file.
    """
    frame = pd.read_csv(path)

    present = set(frame.columns)
    missing = sorted({name for name in INPUT_COLUMNS if name not in present})
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    complete = frame[INPUT_COLUMNS].copy().dropna(subset=INPUT_COLUMNS)
    stripped_text = complete["text_on_sign_exact"].astype(str).str.strip()
    complete = complete.assign(text_on_sign_exact=stripped_text)

    # Text that was only whitespace is empty after stripping; discard it.
    has_text = stripped_text != ""
    return complete[has_text]
|
|
|
|
|
|
def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separator):
    """Build bootstrapped sign-text bags for every intersection.

    Rows of ``data`` are grouped by ``intersection`` and the groups are
    visited in sorted key order. For each group,
    ``samples_per_intersection`` bags are produced; each bag is ``bag_size``
    rows drawn with replacement from the group, with their sign texts joined
    by ``separator``. Every bag carries the group's mean latitude and
    longitude plus the group's raw and distinct sign-text counts.

    Args:
        data: DataFrame with ``intersection``, ``text_on_sign_exact``,
            ``latitude`` and ``longitude`` columns.
        bag_size: Number of sign texts drawn for each bag; at least 1.
        samples_per_intersection: Number of bags per intersection; at
            least 1.
        seed: Base random seed. The bag at output position ``k`` (counted
            across all intersections) is drawn with
            ``random_state=seed + k``, so a given seed always reproduces the
            same output.
        separator: String placed between the sampled texts of a bag.

    Returns:
        A DataFrame with the columns ``intersection``, ``sample_id``,
        ``text``, ``latitude``, ``longitude``, ``unique_sign_count`` and
        ``raw_sign_count``. When ``data`` has no rows the result is empty but
        still has these columns, so a CSV written from it keeps its header.

    Raises:
        ValueError: If ``bag_size`` or ``samples_per_intersection`` is less
            than 1.
    """
    if bag_size < 1:
        raise ValueError("--bag-size must be at least 1")
    if samples_per_intersection < 1:
        raise ValueError("--samples-per-intersection must be at least 1")

    output_columns = [
        "intersection",
        "sample_id",
        "text",
        "latitude",
        "longitude",
        "unique_sign_count",
        "raw_sign_count",
    ]

    rows = []
    for intersection, group in data.groupby("intersection", sort=True):
        # These values are the same for every bag of the intersection, so
        # compute them once per group rather than once per sample.
        texts = group["text_on_sign_exact"].tolist()
        latitude = group["latitude"].mean()
        longitude = group["longitude"].mean()
        unique_sign_count = len(set(texts))
        raw_sign_count = len(texts)

        for sample_id in range(samples_per_intersection):
            # len(rows) is the running output index, which gives each bag
            # its own deterministic seed.
            sampled = group.sample(
                n=bag_size,
                replace=True,
                random_state=seed + len(rows),
            )
            rows.append(
                {
                    "intersection": intersection,
                    "sample_id": sample_id,
                    "text": separator.join(sampled["text_on_sign_exact"].tolist()),
                    "latitude": latitude,
                    "longitude": longitude,
                    "unique_sign_count": unique_sign_count,
                    "raw_sign_count": raw_sign_count,
                }
            )

    # Naming the columns explicitly keeps the schema stable when rows is empty.
    return pd.DataFrame(rows, columns=output_columns)
|
|
|
|
|
|
def main():
    """Run the script end to end: load, bootstrap, write, and report.

    Parses the command-line options, loads the raw CSV, builds the
    bootstrapped training rows, writes them to the output CSV without the
    index column, and prints a one-line summary of what was written.
    """
    args = parse_args()
    raw_data = load_raw_data(args.input_file)

    bootstrap_options = {
        "bag_size": args.bag_size,
        "samples_per_intersection": args.samples_per_intersection,
        "seed": args.seed,
        "separator": args.separator,
    }
    training_data = make_bootstrap_data(raw_data, **bootstrap_options)
    training_data.to_csv(args.output_file, index=False)

    summary = (
        f"Wrote {len(training_data):,} rows from "
        f"{raw_data['intersection'].nunique():,} intersections to {args.output_file}"
    )
    print(summary)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
|