Improve training signal and add honest eval metrics
- prepare_training_data: bag_size 5→8, street signs fill slots first so every sample contains the most geographically discriminative texts
- train: HuberLoss replaces MSE (robust to outlier intersections), ReduceLROnPlateau scheduler added, split logic extracted to data_utils
- eval: reproduce train/val split to report honest per-bag and per-intersection-aggregated metrics separately for train and val sets
- data_utils: shared split_indices() so train and eval use identical splits

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
b72cd1b917
commit
e9d3d72499
27
data_utils.py
Normal file
27
data_utils.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import numpy as np
|
||||||
|
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
||||||
|
|
||||||
|
DEFAULT_TEST_SIZE = 0.2
|
||||||
|
DEFAULT_SEED = 1992
|
||||||
|
|
||||||
|
|
||||||
|
def split_indices(data, test_size=DEFAULT_TEST_SIZE, seed=DEFAULT_SEED):
    """Partition the row positions of ``data`` into train and validation sets.

    When ``data`` has an ``intersection`` column, the split is group-aware:
    all rows that share an intersection value are assigned to the same side.
    Without that column, rows are split individually at random.

    Args:
        data: Table exposing ``len()`` and ``.columns`` (and, for the
            group-aware path, ``data["intersection"]``).
        test_size: Passed through to scikit-learn as the validation share.
        seed: Passed through to scikit-learn as ``random_state``.

    Returns:
        A ``(train_indices, val_indices)`` pair of positional index arrays.
    """
    row_positions = np.arange(len(data))

    # No grouping key available: plain random split over individual rows.
    if "intersection" not in data.columns:
        train_idx, val_idx = train_test_split(
            row_positions, test_size=test_size, random_state=seed
        )
        return train_idx, val_idx

    # Keep every row of one intersection on the same side of the split.
    group_splitter = GroupShuffleSplit(
        n_splits=1, test_size=test_size, random_state=seed
    )
    train_idx, val_idx = next(
        group_splitter.split(row_positions, groups=data["intersection"])
    )
    return train_idx, val_idx
|
||||||
72
eval.py
72
eval.py
@ -11,6 +11,7 @@ import pandas as pd # noqa: E402
|
|||||||
import torch # noqa: E402
|
import torch # noqa: E402
|
||||||
from sentence_transformers import SentenceTransformer # noqa: E402
|
from sentence_transformers import SentenceTransformer # noqa: E402
|
||||||
|
|
||||||
|
from data_utils import split_indices # noqa: E402
|
||||||
from train import CoordinateRegressor, haversine_km # noqa: E402
|
from train import CoordinateRegressor, haversine_km # noqa: E402
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
@ -34,6 +35,18 @@ def parse_args():
|
|||||||
help="Device to use for evaluation. Defaults to `cuda`.",
|
help="Device to use for evaluation. Defaults to `cuda`.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch-size", type=int, default=64)
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=1992,
|
||||||
|
help="Must match the seed used during training for a consistent split.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-size",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Must match the test-size used during training.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -218,7 +231,10 @@ def main():
|
|||||||
coord_std,
|
coord_std,
|
||||||
) = load_model(args.model_path, device)
|
) = load_model(args.model_path, device)
|
||||||
|
|
||||||
data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
|
data = pd.read_csv(args.data_file).dropna(
|
||||||
|
subset=["text", "latitude", "longitude"]
|
||||||
|
)
|
||||||
|
data = data.reset_index(drop=True)
|
||||||
texts = data["text"].astype(str).tolist()
|
texts = data["text"].astype(str).tolist()
|
||||||
predicted = predict(
|
predicted = predict(
|
||||||
trained_encoder,
|
trained_encoder,
|
||||||
@ -248,6 +264,15 @@ def main():
|
|||||||
results["initial_predicted_latitude"] = initial_predicted[:, 0]
|
results["initial_predicted_latitude"] = initial_predicted[:, 0]
|
||||||
results["initial_predicted_longitude"] = initial_predicted[:, 1]
|
results["initial_predicted_longitude"] = initial_predicted[:, 1]
|
||||||
results["error_km"] = errors
|
results["error_km"] = errors
|
||||||
|
|
||||||
|
# Reproduce the same train/val split used during training.
|
||||||
|
train_indices, val_indices = split_indices(
|
||||||
|
data, test_size=args.test_size, seed=args.seed
|
||||||
|
)
|
||||||
|
val_mask = np.zeros(len(results), dtype=bool)
|
||||||
|
val_mask[val_indices] = True
|
||||||
|
results["split"] = np.where(val_mask, "val", "train")
|
||||||
|
|
||||||
results.to_csv(args.output_file, index=False)
|
results.to_csv(args.output_file, index=False)
|
||||||
make_prediction_plot(results, args.plot_file)
|
make_prediction_plot(results, args.plot_file)
|
||||||
make_predicted_vs_actual_plot(results, args.scatter_plot_file)
|
make_predicted_vs_actual_plot(results, args.scatter_plot_file)
|
||||||
@ -255,9 +280,48 @@ def main():
|
|||||||
print(f"Wrote predictions to {args.output_file}")
|
print(f"Wrote predictions to {args.output_file}")
|
||||||
print(f"Wrote plot to {args.plot_file}")
|
print(f"Wrote plot to {args.plot_file}")
|
||||||
print(f"Wrote scatter plot to {args.scatter_plot_file}")
|
print(f"Wrote scatter plot to {args.scatter_plot_file}")
|
||||||
print(f"mean_error_km={errors.mean():.3f}")
|
|
||||||
print(f"median_error_km={np.median(errors):.3f}")
|
def _fmt(errs):
|
||||||
print(f"p90_error_km={np.percentile(errors, 90):.3f}")
|
return (
|
||||||
|
f"mean={np.mean(errs):.3f} "
|
||||||
|
f"median={np.median(errs):.3f} "
|
||||||
|
f"p90={np.percentile(errs, 90):.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _intersection_agg_errors(df):
    """Return haversine errors after averaging predictions per intersection.

    Predicted latitude/longitude are averaged over all rows of the same
    intersection; the first latitude/longitude of each group is used as the
    ground truth.

    If ``df`` has no ``intersection`` column there is nothing to aggregate
    over, so each row is treated as its own unit and per-row errors are
    returned instead of raising ``KeyError``. This matches the rest of the
    evaluation, which also tolerates data without that column.

    Args:
        df: Results frame with ``predicted_latitude``,
            ``predicted_longitude``, ``latitude`` and ``longitude`` columns,
            and optionally ``intersection``.

    Returns:
        Whatever ``haversine_km`` returns for the (prediction, truth) pairs
        -- presumably one distance in km per intersection (or per row in the
        fallback case); confirm against ``train.haversine_km``.
    """
    if "intersection" in df.columns:
        per_intersection = df.groupby("intersection").agg(
            pred_lat=("predicted_latitude", "mean"),
            pred_lon=("predicted_longitude", "mean"),
            true_lat=("latitude", "first"),
            true_lon=("longitude", "first"),
        )
        predicted = per_intersection[["pred_lat", "pred_lon"]]
        actual = per_intersection[["true_lat", "true_lon"]]
    else:
        # No grouping key: fall back to row-level predictions.
        predicted = df[["predicted_latitude", "predicted_longitude"]]
        actual = df[["latitude", "longitude"]]

    return haversine_km(
        predicted.to_numpy(dtype=np.float32),
        actual.to_numpy(dtype=np.float32),
    )
|
||||||
|
|
||||||
|
val_df = results[results["split"] == "val"]
|
||||||
|
train_df = results[results["split"] == "train"]
|
||||||
|
|
||||||
|
val_bag_errors = val_df["error_km"].to_numpy()
|
||||||
|
train_bag_errors = train_df["error_km"].to_numpy()
|
||||||
|
val_agg_errors = _intersection_agg_errors(val_df)
|
||||||
|
train_agg_errors = _intersection_agg_errors(train_df)
|
||||||
|
|
||||||
|
n_val = val_df["intersection"].nunique() if "intersection" in val_df.columns else "?"
|
||||||
|
n_train = train_df["intersection"].nunique() if "intersection" in train_df.columns else "?"
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"[all bags ({len(errors):>6} rows)] {_fmt(errors)}")
|
||||||
|
print(f"[train bags({len(train_bag_errors):>6} rows)] {_fmt(train_bag_errors)}")
|
||||||
|
print(f"[val bags ({len(val_bag_errors):>6} rows)] {_fmt(val_bag_errors)}")
|
||||||
|
print(f"[train agg ({n_train:>3} intersections)] {_fmt(train_agg_errors)}")
|
||||||
|
print(f"[val agg ({n_val:>3} intersections)] {_fmt(val_agg_errors)} ← generalization")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import random
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
INPUT_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
BASE_COLUMNS = ["intersection", "text_on_sign_exact", "latitude", "longitude"]
|
||||||
|
OPTIONAL_COLUMNS = ["code_type"]
|
||||||
EXCLUDED_INTERSECTIONS = {"56th-pena"}
|
EXCLUDED_INTERSECTIONS = {"56th-pena"}
|
||||||
|
|
||||||
|
|
||||||
@ -31,8 +33,11 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bag-size",
|
"--bag-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=8,
|
||||||
help="Number of sign texts to sample for each training row.",
|
help=(
|
||||||
|
"Total number of sign texts per training row. Street signs fill "
|
||||||
|
"slots first; remaining slots are randomly sampled from all signs."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--samples-per-intersection",
|
"--samples-per-intersection",
|
||||||
@ -50,12 +55,13 @@ def parse_args():
|
|||||||
|
|
||||||
def load_raw_data(path):
|
def load_raw_data(path):
|
||||||
data = pd.read_csv(path)
|
data = pd.read_csv(path)
|
||||||
missing = sorted(set(INPUT_COLUMNS) - set(data.columns))
|
missing = sorted(set(BASE_COLUMNS) - set(data.columns))
|
||||||
if missing:
|
if missing:
|
||||||
raise ValueError(f"Missing required columns: {missing}")
|
raise ValueError(f"Missing required columns: {missing}")
|
||||||
|
|
||||||
data = data[INPUT_COLUMNS].copy()
|
keep = BASE_COLUMNS + [c for c in OPTIONAL_COLUMNS if c in data.columns]
|
||||||
data = data.dropna(subset=INPUT_COLUMNS)
|
data = data[keep].copy()
|
||||||
|
data = data.dropna(subset=BASE_COLUMNS)
|
||||||
data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip()
|
data["text_on_sign_exact"] = data["text_on_sign_exact"].astype(str).str.strip()
|
||||||
data = data[data["text_on_sign_exact"] != ""]
|
data = data[data["text_on_sign_exact"] != ""]
|
||||||
data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)]
|
data = data[~data["intersection"].isin(EXCLUDED_INTERSECTIONS)]
|
||||||
@ -68,28 +74,48 @@ def make_bootstrap_data(data, bag_size, samples_per_intersection, seed, separato
|
|||||||
if samples_per_intersection < 1:
|
if samples_per_intersection < 1:
|
||||||
raise ValueError("--samples-per-intersection must be at least 1")
|
raise ValueError("--samples-per-intersection must be at least 1")
|
||||||
|
|
||||||
|
has_code_type = "code_type" in data.columns
|
||||||
rows = []
|
rows = []
|
||||||
grouped = data.groupby("intersection", sort=True)
|
grouped = data.groupby("intersection", sort=True)
|
||||||
for intersection, group in grouped:
|
for intersection, group in grouped:
|
||||||
texts = group["text_on_sign_exact"].tolist()
|
|
||||||
latitude = group["latitude"].mean()
|
latitude = group["latitude"].mean()
|
||||||
longitude = group["longitude"].mean()
|
longitude = group["longitude"].mean()
|
||||||
|
all_texts = group["text_on_sign_exact"].tolist()
|
||||||
|
|
||||||
|
# Street signs are the most geographically discriminative; always
|
||||||
|
# include them first up to bag_size, then fill remaining slots randomly.
|
||||||
|
if has_code_type:
|
||||||
|
street_texts = group[group["code_type"] == "street_sign"][
|
||||||
|
"text_on_sign_exact"
|
||||||
|
].tolist()
|
||||||
|
else:
|
||||||
|
street_texts = []
|
||||||
|
|
||||||
for sample_id in range(samples_per_intersection):
|
for sample_id in range(samples_per_intersection):
|
||||||
sampled = group.sample(
|
rng_state = seed + len(rows)
|
||||||
n=bag_size,
|
|
||||||
replace=True,
|
guaranteed = street_texts[:bag_size]
|
||||||
random_state=seed + len(rows),
|
remaining = bag_size - len(guaranteed)
|
||||||
)
|
if remaining > 0:
|
||||||
|
filler = group.sample(
|
||||||
|
n=remaining, replace=True, random_state=rng_state
|
||||||
|
)["text_on_sign_exact"].tolist()
|
||||||
|
else:
|
||||||
|
filler = []
|
||||||
|
|
||||||
|
bag = guaranteed + filler
|
||||||
|
# Shuffle so position in the sequence doesn't encode sign type.
|
||||||
|
random.Random(rng_state).shuffle(bag)
|
||||||
|
|
||||||
rows.append(
|
rows.append(
|
||||||
{
|
{
|
||||||
"intersection": intersection,
|
"intersection": intersection,
|
||||||
"sample_id": sample_id,
|
"sample_id": sample_id,
|
||||||
"text": separator.join(sampled["text_on_sign_exact"].tolist()),
|
"text": separator.join(bag),
|
||||||
"latitude": latitude,
|
"latitude": latitude,
|
||||||
"longitude": longitude,
|
"longitude": longitude,
|
||||||
"unique_sign_count": len(set(texts)),
|
"unique_sign_count": len(set(all_texts)),
|
||||||
"raw_sign_count": len(texts),
|
"raw_sign_count": len(all_texts),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
33
train.py
33
train.py
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import LoggingHandler, SentenceTransformer
|
from sentence_transformers import LoggingHandler, SentenceTransformer
|
||||||
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
from data_utils import split_indices
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
@ -91,9 +91,9 @@ def parse_args():
|
|||||||
"epoch). Validation still runs every epoch."
|
"epoch). Validation still runs every epoch."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--learning-rate", type=float, default=2e-5)
|
parser.add_argument("--learning-rate", type=float, default=1e-4)
|
||||||
parser.add_argument("--head-learning-rate", type=float, default=1e-3)
|
parser.add_argument("--head-learning-rate", type=float, default=5e-2)
|
||||||
parser.add_argument("--weight-decay", type=float, default=0.01)
|
parser.add_argument("--weight-decay", type=float, default=0.001)
|
||||||
parser.add_argument("--test-size", type=float, default=0.2)
|
parser.add_argument("--test-size", type=float, default=0.2)
|
||||||
parser.add_argument("--hidden-dim", type=int, default=256)
|
parser.add_argument("--hidden-dim", type=int, default=256)
|
||||||
parser.add_argument("--dropout", type=float, default=0.1)
|
parser.add_argument("--dropout", type=float, default=0.1)
|
||||||
@ -433,18 +433,9 @@ def main():
|
|||||||
coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
|
coordinates = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
|
||||||
normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
|
normalized_coordinates, coord_mean, coord_std = normalize_coordinates(coordinates)
|
||||||
|
|
||||||
indices = np.arange(len(data))
|
train_indices, val_indices = split_indices(
|
||||||
if "intersection" in data.columns:
|
data, test_size=args.test_size, seed=args.seed
|
||||||
splitter = GroupShuffleSplit(
|
)
|
||||||
n_splits=1, test_size=args.test_size, random_state=args.seed
|
|
||||||
)
|
|
||||||
train_indices, val_indices = next(
|
|
||||||
splitter.split(indices, groups=data["intersection"])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
train_indices, val_indices = train_test_split(
|
|
||||||
indices, test_size=args.test_size, random_state=args.seed
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset = SignCoordinateDataset(
|
train_dataset = SignCoordinateDataset(
|
||||||
[texts[i] for i in train_indices], normalized_coordinates[train_indices]
|
[texts[i] for i in train_indices], normalized_coordinates[train_indices]
|
||||||
@ -515,7 +506,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
optimizer = make_optimizer(encoder, head, args)
|
optimizer = make_optimizer(encoder, head, args)
|
||||||
loss_fn = nn.MSELoss()
|
loss_fn = nn.HuberLoss()
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer, mode="min", patience=5, factor=0.5, min_lr=1e-7
|
||||||
|
)
|
||||||
best_val_loss = float("inf")
|
best_val_loss = float("inf")
|
||||||
best_states = None
|
best_states = None
|
||||||
pending_save = False
|
pending_save = False
|
||||||
@ -552,9 +546,12 @@ def main():
|
|||||||
coord_std,
|
coord_std,
|
||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
scheduler.step(val_loss)
|
||||||
|
current_lr = optimizer.param_groups[-1]["lr"]
|
||||||
print(
|
print(
|
||||||
f"epoch={epoch} train_loss={train_loss:.6f} "
|
f"epoch={epoch} train_loss={train_loss:.6f} "
|
||||||
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f}"
|
f"val_loss={val_loss:.6f} val_error_km={val_error_km:.3f} "
|
||||||
|
f"lr={current_lr:.2e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user