"""Evaluate a trained sign-text coordinate model.

Loads the fine-tuned SentenceTransformer encoder and regression head saved in
``--model-path`` (plus, when present, the pre-training encoder/head snapshot),
predicts latitude/longitude for every row of ``--data-file``, writes the
predictions and per-row haversine errors to ``--output-file``, and renders a
map-style plot and a predicted-vs-actual scatter plot.
"""

import argparse
import json
import os
import tempfile

# Redirect matplotlib's config dir and XDG_CACHE_HOME to a writable temp
# location *before* the libraries that read them are imported. "/private/tmp"
# only exists on macOS; elsewhere use the platform temp dir so that creating
# directories under these paths does not fail on a missing root.
_TMP_ROOT = "/private/tmp" if os.path.isdir("/private/tmp") else tempfile.gettempdir()
os.environ.setdefault("MPLCONFIGDIR", os.path.join(_TMP_ROOT, "matplotlib"))
os.environ.setdefault("XDG_CACHE_HOME", _TMP_ROOT)

import matplotlib  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import torch  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402

from train import CoordinateRegressor, haversine_km  # noqa: E402

# Non-interactive backend: this script only writes image files.
matplotlib.use("Agg")
from matplotlib import pyplot as plt  # noqa: E402


def parse_args():
    """Parse command-line options for evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate sign coordinate model.")
    parser.add_argument("--data-file", default="training.csv")
    parser.add_argument("--model-path", default="output")
    parser.add_argument("--output-file", default="predictions.csv")
    parser.add_argument("--plot-file", default="plots/prediction_map.png")
    parser.add_argument(
        "--scatter-plot-file",
        default="plots/predicted_vs_actual.png",
        help="2x1 scatter plot of predicted vs actual latitude and longitude.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help=(
            "Device to use for evaluation (e.g. `cuda`, `cpu`). "
            "Defaults to `cuda` when available, otherwise `cpu`."
        ),
    )
    parser.add_argument("--batch-size", type=int, default=64)
    return parser.parse_args()


def get_device(requested_device):
    """Return the requested torch device, or auto-select CUDA/CPU if unset.

    An empty or ``None`` request falls back to CUDA when available, else CPU.
    """
    if requested_device:
        return torch.device(requested_device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one.

    ``os.makedirs("")`` raises, so bare filenames (no directory part) are
    skipped rather than crashing.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def load_head(model_path, filename, encoder, config, device):
    """Build a CoordinateRegressor sized for ``encoder`` and load its weights.

    Args:
        model_path: Directory containing the saved state dict.
        filename: State-dict file name inside ``model_path``.
        encoder: SentenceTransformer whose embedding size the head consumes.
        config: Training config; must provide ``hidden_dim`` and ``dropout``.
        device: Torch device to place the head on.

    Returns:
        The head in eval mode.
    """
    head = CoordinateRegressor(
        embedding_dim=encoder.get_sentence_embedding_dimension(),
        hidden_dim=config["hidden_dim"],
        dropout=config["dropout"],
    ).to(device)
    # NOTE(review): this is a plain state_dict; on torch>=1.13 consider
    # passing weights_only=True so untrusted checkpoints cannot run code.
    head.load_state_dict(
        torch.load(
            os.path.join(model_path, filename),
            map_location=device,
        )
    )
    head.eval()
    return head


def load_model(model_path, device):
    """Load the trained and the initial (pre-training) encoder/head pairs.

    The initial encoder comes from ``<model_path>/initial_encoder`` when that
    exists, otherwise from ``config["model_name"]``. The initial head is only
    loaded when ``initial_coordinate_head.pt`` exists, else it is ``None``.

    Returns:
        Tuple ``(base_encoder, initial_head, trained_encoder, trained_head,
        coord_mean, coord_std)``; the last two are numpy arrays used to
        de-normalize head outputs.
    """
    with open(os.path.join(model_path, "coordinate_config.json"), encoding="utf-8") as f:
        config = json.load(f)

    trained_encoder = SentenceTransformer(model_path, device=str(device))
    trained_head = load_head(
        model_path, "coordinate_head.pt", trained_encoder, config, device
    )
    trained_encoder.eval()

    initial_encoder_path = os.path.join(model_path, "initial_encoder")
    if os.path.exists(initial_encoder_path):
        base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
    else:
        base_encoder = SentenceTransformer(config["model_name"], device=str(device))

    initial_head_filename = "initial_coordinate_head.pt"
    initial_head = None
    if os.path.exists(os.path.join(model_path, initial_head_filename)):
        initial_head = load_head(
            model_path, initial_head_filename, base_encoder, config, device
        )
    base_encoder.eval()

    return (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        np.array(config["coord_mean"]),
        np.array(config["coord_std"]),
    )


@torch.no_grad()
def predict(encoder, head, texts, coord_mean, coord_std, device, batch_size):
    """Predict de-normalized coordinates for ``texts`` in batches.

    Each batch is tokenized, embedded by ``encoder`` and regressed by
    ``head``; the stacked outputs are mapped back with
    ``output * coord_std + coord_mean``.

    Returns:
        numpy array with one row per text (empty, with ``len(coord_mean)``
        columns, when ``texts`` is empty).
    """
    predictions = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        features = {key: value.to(device) for key, value in features.items()}
        embeddings = encoder(features)["sentence_embedding"]
        predictions.append(head(embeddings).cpu().numpy())

    if not predictions:
        # np.vstack([]) raises; return a correctly-shaped empty result instead.
        return np.empty((0, len(coord_mean))) * coord_std + coord_mean
    normalized = np.vstack(predictions)
    return normalized * coord_std + coord_mean


def make_prediction_plot(results, plot_file):
    """Plot actual vs predicted positions on a lon/lat map with connectors.

    Expects columns ``latitude``, ``longitude``, ``predicted_latitude`` and
    ``predicted_longitude`` in ``results``.
    """
    _ensure_parent_dir(plot_file)
    fig, ax = plt.subplots(figsize=(8, 8))

    # One grey segment per row from actual to predicted point. Passing 2xN
    # arrays draws each column as its own line in a single call instead of
    # one ax.plot per row.
    segment_x = np.vstack(
        [results["longitude"].to_numpy(), results["predicted_longitude"].to_numpy()]
    )
    segment_y = np.vstack(
        [results["latitude"].to_numpy(), results["predicted_latitude"].to_numpy()]
    )
    ax.plot(segment_x, segment_y, color="0.75", linewidth=0.4, alpha=0.35)

    ax.scatter(
        results["longitude"],
        results["latitude"],
        s=18,
        color="#1f77b4",
        alpha=0.75,
        label="actual",
    )
    ax.scatter(
        results["predicted_longitude"],
        results["predicted_latitude"],
        s=12,
        color="#d62728",
        alpha=0.45,
        label="predicted",
    )
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")
    ax.set_title("Sign Text Coordinate Predictions")
    ax.legend()
    ax.set_aspect("equal", adjustable="box")
    fig.tight_layout()
    fig.savefig(plot_file, dpi=200)
    plt.close(fig)


def make_predicted_vs_actual_plot(results, plot_file):
    """Draw a 2x1 predicted-vs-actual scatter for latitude and longitude.

    When ``initial_predicted_*`` columns exist in ``results``, the
    before-training predictions are overlaid. Each panel gets a y = x
    reference line and shared square limits covering all plotted values.
    """
    _ensure_parent_dir(plot_file)
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(7, 9))
    plot_specs = [
        ("latitude", "initial_predicted_latitude", "predicted_latitude", "Latitude"),
        (
            "longitude",
            "initial_predicted_longitude",
            "predicted_longitude",
            "Longitude",
        ),
    ]

    for ax, (actual_col, initial_col, predicted_col, title) in zip(axes, plot_specs):
        actual = results[actual_col]
        initial = results[initial_col] if initial_col in results else None
        predicted = results[predicted_col]

        values = [actual, predicted]
        if initial is not None:
            values.append(initial)
        lower = min(series.min() for series in values)
        upper = max(series.max() for series in values)

        if initial is not None:
            ax.scatter(
                actual,
                initial,
                s=14,
                alpha=0.45,
                color="#d62728",
                label="before training",
            )
        ax.scatter(
            actual,
            predicted,
            s=14,
            alpha=0.55,
            color="#1f77b4",
            label="after training",
        )
        ax.plot([lower, upper], [lower, upper], color="0.25", linewidth=1)
        ax.set_xlabel(f"actual {actual_col}")
        ax.set_ylabel(f"predicted {actual_col}")
        ax.set_title(title)
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.legend()

    fig.tight_layout()
    fig.savefig(plot_file, dpi=200)
    plt.close(fig)


def main():
    """Run evaluation end to end: load, predict, write CSV/plots, print stats."""
    args = parse_args()
    device = get_device(args.device)
    (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    ) = load_model(args.model_path, device)

    data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
    if data.empty:
        # Summary statistics (median/percentile) are undefined on no rows.
        raise SystemExit(
            f"No usable rows in {args.data_file} "
            "(need non-null text, latitude and longitude)."
        )
    texts = data["text"].astype(str).tolist()

    predicted = predict(
        trained_encoder,
        trained_head,
        texts,
        coord_mean,
        coord_std,
        device,
        args.batch_size,
    )
    true_coords = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    errors = haversine_km(predicted, true_coords)

    results = data.copy()
    results["predicted_latitude"] = predicted[:, 0]
    results["predicted_longitude"] = predicted[:, 1]
    if initial_head is not None:
        initial_predicted = predict(
            base_encoder,
            initial_head,
            texts,
            coord_mean,
            coord_std,
            device,
            args.batch_size,
        )
        results["initial_predicted_latitude"] = initial_predicted[:, 0]
        results["initial_predicted_longitude"] = initial_predicted[:, 1]
    results["error_km"] = errors

    _ensure_parent_dir(args.output_file)
    results.to_csv(args.output_file, index=False)
    make_prediction_plot(results, args.plot_file)
    make_predicted_vs_actual_plot(results, args.scatter_plot_file)

    print(f"Wrote predictions to {args.output_file}")
    print(f"Wrote plot to {args.plot_file}")
    print(f"Wrote scatter plot to {args.scatter_plot_file}")
    print(f"mean_error_km={errors.mean():.3f}")
    print(f"median_error_km={np.median(errors):.3f}")
    print(f"p90_error_km={np.percentile(errors, 90):.3f}")


if __name__ == "__main__":
    main()