Improve training signal and add honest eval metrics
- prepare_training_data: bag_size 5→8, street signs fill slots first so every sample contains the most geographically discriminative texts - train: HuberLoss replaces MSE (robust to outlier intersections), ReduceLROnPlateau scheduler added, split logic extracted to data_utils - eval: reproduce train/val split to report honest per-bag and per-intersection-aggregated metrics separately for train and val sets - data_utils: shared split_indices() so train and eval use identical splits Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
b72cd1b917
commit
e9d3d72499
27
data_utils.py
Normal file
27
data_utils.py
Normal file
@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
||||
|
||||
DEFAULT_TEST_SIZE = 0.2
|
||||
DEFAULT_SEED = 1992
|
||||
|
||||
|
||||
def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED):
|
||||
"""Return (train_indices, val_indices) arrays.
|
||||
|
||||
Uses a group-aware split on the ``intersection`` column when present so
|
||||
that every row from the same intersection lands in the same partition.
|
||||
Falls back to a plain random split otherwise.
|
||||
"""
|
||||
indices = np.arange(len(data))
|
||||
if "intersection" in data.columns:
|
||||
splitter = GroupShuffleSplit(
|
||||
n_splits=1, test_size=test_size, random_state=seed
|
||||
)
|
||||
train_idx, val_idx = next(
|
||||
splitter.split(indices, groups=data["intersection"])
|
||||
)
|
||||
else:
|
||||
train_idx, val_idx = train_test_split(
|
||||
indices, test_size=test_size, random_state=seed
|
||||
)
|
||||
return train_idx, val_idx
|
||||
72
eval.py
72
eval.py
@ -11,6 +11,7 @@ import pandas as pd # noqa: E402
|
||||
import torch # noqa: E402
|
||||
from sentence_transformers import SentenceTransformer # noqa: E402
|
||||
|
||||
from data_utils import split_indices # noqa: E402
|
||||
from train import CoordinateRegressor, haversine_km # noqa: E402
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@ -34,6 +35,18 @@ def parse_args():
|
||||
help="Device to use for evaluation. Defaults to `cuda`.",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=1992,
|
||||
help="Must match the seed used during training for a consistent split.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-size",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Must match the test-size used during training.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -218,7 +231,10 @@ def main():
|
||||
coord_std,
|
||||
) = load_model(args.model_path, device)
|
||||
|
||||
data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
|
||||
data = pd.read_csv(args.data_file).dropna(
|
||||
subset=["text", "latitude", "longitude"]
|
||||
)
|
||||
data = data.reset_index(drop=True)
|
||||
texts = data["text"].astype(str).tolist()
|
||||
predicted = predict(
|
||||
trained_encoder,
|
||||
@ -248,6 +264,15 @@ def main():
|
||||
results["initial_predicted_latitude"] = initial_predicted[:, 0]
|
||||
results["initial_predicted_longitude"] = initial_predicted[:, 1]
|
||||
results["error_km"] = errors
|
||||
|
||||
# Reproduce the same train/val split used during training.
|
||||
train_indices, val_indices = split_indices(
|
||||
data, test_size=args.test_size, seed=args.seed
|
||||
)
|
||||
val_mask = np.zeros(len(results), dtype=bool)
|
||||
val_mask[val_indices] = True
|
||||
results["split"] = np.where(val_mask, "val", "train")
|
||||
|
||||
results.to_csv(args.output_file, index=False)
|
||||
make_prediction_plot(results, args.plot_file)
|
||||
make_predicted_vs_actual_plot(results, args.scatter_plot_file)
|
||||
@ -255,9 +280,48 @@ def main():
|
||||
print(f"Wrote predictions to {args.output_file}")
|
||||
print(f"Wrote plot to {args.plot_file}")
|
||||
print(f"Wrote scatter plot to {args.scatter_plot_file}")
|
||||
print(f"mean_error_km={errors.mean():.3f}")
|
||||
print(f"median_error_km={np.median(errors):.3f}")
|
||||
print(f"p90_error_km={np.percentile(errors, 90):.3f}")
|
||||
|
||||
def _fmt(errs):
|
||||
return (
|
||||
f"mean={np.mean(errs):.3f} "
|
||||
f"median={np.median(errs):.3f} "
|
||||
f"p90={np.percentile(errs, 90):.3f}"
|
||||
)
|
||||
|
||||
def _intersection_agg_errors(df):
|
||||
"""Average per-bag predictions to one coordinate per intersection."""
|
||||
agg = (
|
||||
df.groupby("intersection")
|
||||
.agg(
|
||||
pred_lat=("predicted_latitude", "mean"),
|
||||
pred_lon=("predicted_longitude", "mean"),
|
||||
true_lat=("latitude", "first"),
|
||||
true_lon=("longitude", "first"),
|
||||
)
|
||||
.reset_index()
|
||||
)
|
||||
return haversine_km(
|
||||
agg[["pred_lat", "pred_lon"]].to_numpy(dtype=np.float32),
|
||||
agg[["true_lat", "true_lon"]].to_numpy(dtype=np.float32),
|
||||
)
|
||||
|
||||
val_df = results[results["split"] == "val"]
|
||||
train_df = results[results["split"] == "train"]
|
||||
|
||||
val_bag_errors = val_df["error_km"].to_numpy()
|
||||
train_bag_errors = train_df["error_km"].to_numpy()
|
||||
val_agg_errors = _intersection_agg_errors(val_df)
|
||||
train_agg_errors = _intersection_agg_errors(train_df)
|
||||
|
||||
n_val = val_df["intersection"].nunique() if "intersection" in val_df.columns else "?"
|
||||
n_train = train_df["intersection"].nunique() if "intersection" in train_df.columns else "?"
|
||||
|
||||
print()
|
||||
print(f"[all bags ({len(errors):>6} rows)] {_fmt(errors)}")
|
||||
print(f"[train bags({len(train_bag_errors):>6} rows)] {_fmt(train_bag_errors)}")
|
||||
print(f"[val bags ({len(val_bag_errors):>6} rows)] {_fmt(val_bag_errors)}")
|
||||
print(f"[train agg ({n_train:>3} intersections)] {_fmt(train_agg_errors)}")
|
||||
print(f"[val agg ({n_val:>3} intersections)] {_fmt(val_agg_errors)} ← generalization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
|
||||
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
||||
BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
||||
OPTIONAL_COLUMNS = ["code_type"]
|
||||
EXCLUDED_INTERSECTIONS = {"56th-pena"}
|
||||
|
||||
|
||||
@ -31,8 +33,11 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--bag-size",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of sign texts to sample for each training row.",
|
||||
default=8,
|
||||
help=(
|
||||
"Total number of sign texts per training row. Street signs fill "
|
||||
"slots first; remaining slots are randomly sampled from all signs."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--samples-per-intersection",
|
||||
@ -50,12 +55,13 @@ def parse_args():
|
||||
|
||||
def load_raw_data(path):
|
||||
data = pd.read_csv(path)
|
||||
missing = sorted(set(INPUT_COLUMNS) - set(data.columns))
|
||||
missing = sorted(set(BASE_COLUMNS) - set(data.columns))
|
||||
if missing:
|
||||
raise ValueError(f"Missing required columns: {missing}")
|
||||
|
||||
data = data[INPUT_COLUMNS].copy()
|
||||
data = data.dropna(subset=INPUT_COLUMNS)
|
||||
keep = BASE_COLUMNS + [c for c in OPTIONAL_COLUMNS if c in data.columns]
|
||||
data = data[keep].copy()
|
||||
data = data.dropna(subset=BASE_COLUMNS)
|
||||
data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip()
|
||||
data = data[data["text_on_sign_exact"] != ""]
|
||||
data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)]
|
||||
@ -68,28 +74,48 @@ def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separato
|
||||
if samples_per_intersection < 1:
|
||||
raise ValueError("--samples-per-intersection must be at least 1")
|
||||
|
||||
has_code_type = "code_type" in data.columns
|
||||
rows = []
|
||||
grouped = data.groupby("intersection", sort=True)
|
||||
for intersection, group in grouped:
|
||||
texts = group["text_on_sign_exact"].tolist()
|
||||
latitude = group["latitude"].mean()
|
||||
longitude = group["longitude"].mean()
|
||||
all_texts = group["text_on_sign_exact"].tolist()
|
||||
|
||||
# Street signs are the most geographically discriminative; always
|
||||
# include them first up to bag_size, then fill remaining slots randomly.
|
||||
if has_code_type:
|
||||
street_texts = group[group["code_type"] == "street_sign"][
|
||||
"text_on_sign_exact"
|
||||
].tolist()
|
||||
else:
|
||||
street_texts = []
|
||||
|
||||
for sample_id in range(samples_per_intersection):
|
||||
sampled = group.sample(
|
||||
n=bag_size,
|
||||
replace=True,
|
||||
random_state=seed + len(rows),
|
||||
)
|
||||
rng_state = seed + len(rows)
|
||||
|
||||
guaranteed = street_texts[:bag_size]
|
||||
remaining = bag_size - len(guaranteed)
|
||||
if remaining > 0:
|
||||
filler = group.sample(
|
||||
n=remaining, replace=True, random_state=rng_state
|
||||
)["text_on_sign_exact"].tolist()
|
||||
else:
|
||||
filler = []
|
||||
|
||||
bag = guaranteed + filler
|
||||
# Shuffle so position in the sequence doesn't encode sign type.
|
||||
random.Random(rng_state).shuffle(bag)
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"intersection": intersection,
|
||||
"sample_id": sample_id,
|
||||
"text": separator.join(sampled["text_on_sign_exact"].tolist()),
|
||||
"text": separator.join(bag),
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"unique_sign_count": len(set(texts)),
|
||||
"raw_sign_count": len(texts),
|
||||
"unique_sign_count": len(set(all_texts)),
|
||||
"raw_sign_count": len(all_texts),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
31
train.py
31
train.py
@ -8,7 +8,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sentence_transformers import LoggingHandler, SentenceTransformer
|
||||
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
||||
from data_utils import split_indices
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
@ -91,9 +91,9 @@ def parse_args():
|
||||
"epoch). Validation still runs every epoch."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--learning-rate", type=float, default=2e-5)
|
||||
parser.add_argument("--head-learning-rate", type=float, default=1e-3)
|
||||
parser.add_argument("--weight-decay", type=float, default=0.01)
|
||||
parser.add_argument("--learning-rate", type=float, default=1e-4)
|
||||
parser.add_argument("--head-learning-rate", type=float, default=5e-2)
|
||||
parser.add_argument("--weight-decay", type=float, default=0.001)
|
||||
parser.add_argument("--test-size", type=float, default=0.2)
|
||||
parser.add_argument("--hidden-dim", type=int, default=256)
|
||||
parser.add_argument("--dropout", type=float, default=0.1)
|
||||
@ -433,17 +433,8 @@ def main():
|
||||
coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
|
||||
normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
|
||||
|
||||
indices = np.arange(len(data))
|
||||
if "intersection" in data.columns:
|
||||
splitter = GroupShuffleSplit(
|
||||
n_splits=1, test_size=args.test_size, random_state=args.seed
|
||||
)
|
||||
train_indices, val_indices = next(
|
||||
splitter.split(indices, groups=data["intersection"])
|
||||
)
|
||||
else:
|
||||
train_indices, val_indices = train_test_split(
|
||||
indices, test_size=args.test_size, random_state=args.seed
|
||||
train_indices, val_indices = split_indices(
|
||||
data, test_size=args.test_size, seed=args.seed
|
||||
)
|
||||
|
||||
train_dataset = SignCoordinateDataset(
|
||||
@ -515,7 +506,10 @@ def main():
|
||||
)
|
||||
|
||||
optimizer = make_optimizer(encoder, head, args)
|
||||
loss_fn = nn.MSELoss()
|
||||
loss_fn = nn.HuberLoss()
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7
|
||||
)
|
||||
best_val_loss = float("inf")
|
||||
best_states = None
|
||||
pending_save = False
|
||||
@ -552,9 +546,12 @@ def main():
|
||||
coord_std,
|
||||
device,
|
||||
)
|
||||
scheduler.step(val_loss)
|
||||
current_lr = optimizer.param_groups[-1]["lr"]
|
||||
print(
|
||||
f"epoch={epoch} train_loss={train_loss:.6f} "
|
||||
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
|
||||
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} "
|
||||
f"lr={current_lr:.2e}"
|
||||
)
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user