citybert/eval.py
Michael Pilosov bfad4547ce no mps
2026-05-25 14:26:37 -06:00

256 lines
7.9 KiB
Python

import argparse
import json
import os
# Set matplotlib's config directory and the XDG cache directory before the
# imports below run; setdefault keeps any value already present in the
# environment.
# NOTE(review): /private/tmp looks macOS-specific — confirm this is intended on
# other platforms.
os.environ.setdefault("MPLCONFIGDIR", "/private/tmp/matplotlib")
os.environ.setdefault("XDG_CACHE_HOME", "/private/tmp")
import matplotlib # noqa: E402
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
import torch # noqa: E402
from sentence_transformers import SentenceTransformer # noqa: E402
from train import CoordinateRegressor, haversine_km # noqa: E402
# Select the Agg backend before pyplot is imported; the plotting helpers in
# this file write their figures to disk with savefig.
matplotlib.use("Agg")
from matplotlib import pyplot as plt # noqa: E402
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``data_file``, ``model_path``, ``output_file``,
        ``plot_file``, ``scatter_plot_file``, ``device`` and ``batch_size``.
    """
    cli = argparse.ArgumentParser(description="Evaluate sign coordinate model.")

    # Plain string options that need nothing beyond a default value.
    simple_defaults = (
        ("--data-file", "training.csv"),
        ("--model-path", "output"),
        ("--output-file", "predictions.csv"),
        ("--plot-file", "plots/prediction_map.png"),
    )
    for flag, default in simple_defaults:
        cli.add_argument(flag, default=default)

    cli.add_argument(
        "--scatter-plot-file",
        default="plots/predicted_vs_actual.png",
        help="2x1 scatter plot of predicted vs actual latitude and longitude.",
    )
    # No default device: the caller picks one when the flag is omitted.
    cli.add_argument("--device", default=None)
    cli.add_argument("--batch-size", type=int, default=64)
    return cli.parse_args()
def get_device(requested_device):
    """Resolve the torch device to run on.

    A non-empty ``requested_device`` always wins. Otherwise CUDA is used when
    it is available and the CPU is the fallback.
    """
    if not requested_device:
        fallback = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(fallback)
    return torch.device(requested_device)
def load_head(model_path, filename, encoder, config, device):
    """Build a ``CoordinateRegressor`` head and load its saved weights.

    Args:
        model_path: Directory that contains the saved weights file.
        filename: Name of the state-dict file inside ``model_path``.
        encoder: Sentence encoder; its embedding dimension sizes the head input.
        config: Mapping providing the ``hidden_dim`` and ``dropout`` values used
            to construct the head.
        device: Device the head is moved to and the weights are mapped onto.

    Returns:
        The head with its weights loaded, switched to eval mode.
    """
    head = CoordinateRegressor(
        embedding_dim=encoder.get_sentence_embedding_dimension(),
        hidden_dim=config["hidden_dim"],
        dropout=config["dropout"],
    ).to(device)
    # The file is only ever consumed as a state dict, so restrict unpickling to
    # tensors and plain containers rather than arbitrary pickled objects. This
    # is explicit because older torch releases default to a full unpickle.
    state_dict = torch.load(
        os.path.join(model_path, filename),
        map_location=device,
        weights_only=True,
    )
    head.load_state_dict(state_dict)
    head.eval()
    return head
def load_model(model_path, device):
    """Load the trained and baseline encoders/heads plus coordinate statistics.

    Reads ``coordinate_config.json`` from ``model_path``. The baseline encoder
    comes from the ``initial_encoder`` subdirectory when it exists, otherwise
    from ``config["model_name"]``. The baseline head is loaded only when
    ``initial_coordinate_head.pt`` is present.

    Returns:
        Tuple ``(base_encoder, initial_head, trained_encoder, trained_head,
        coord_mean, coord_std)``. ``initial_head`` is ``None`` when no initial
        head file exists; the last two items are numpy arrays built from the
        config's ``coord_mean`` and ``coord_std`` entries.
    """
    device_name = str(device)
    config_path = os.path.join(model_path, "coordinate_config.json")
    with open(config_path) as config_file:
        config = json.load(config_file)

    # Fine-tuned encoder and its regression head.
    trained_encoder = SentenceTransformer(model_path, device=device_name)
    trained_head = load_head(
        model_path, "coordinate_head.pt", trained_encoder, config, device
    )
    trained_encoder.eval()

    # Baseline encoder: prefer the snapshot saved next to the model and only
    # fall back to the configured model name when that snapshot is missing.
    snapshot_dir = os.path.join(model_path, "initial_encoder")
    base_source = snapshot_dir if os.path.exists(snapshot_dir) else config["model_name"]
    base_encoder = SentenceTransformer(base_source, device=device_name)

    initial_head_name = "initial_coordinate_head.pt"
    if os.path.exists(os.path.join(model_path, initial_head_name)):
        initial_head = load_head(
            model_path, initial_head_name, base_encoder, config, device
        )
    else:
        initial_head = None
    base_encoder.eval()

    coord_mean = np.array(config["coord_mean"])
    coord_std = np.array(config["coord_std"])
    return (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    )
@torch.no_grad()
def predict(encoder, head, texts, coord_mean, coord_std, device, batch_size):
    """Predict de-normalized coordinates for ``texts`` in mini-batches.

    Args:
        encoder: Object exposing ``tokenize(batch)`` and, when called on the
            tokenized features, returning a mapping with a
            ``"sentence_embedding"`` entry.
        head: Callable mapping embeddings to normalized predictions.
        texts: Sequence of input strings.
        coord_mean: Offset added back to the normalized predictions.
        coord_std: Scale applied to the normalized predictions.
        device: Device the tokenized tensors are moved to.
        batch_size: Number of texts per forward pass; must be positive.

    Returns:
        ``numpy.ndarray`` with one row per text, computed as
        ``normalized * coord_std + coord_mean``. Empty ``texts`` yields an
        array with zero rows instead of raising.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if len(texts) == 0:
        # np.vstack raises on an empty list. Return a zero-row array whose
        # trailing dimension matches what broadcasting against the statistics
        # would produce.
        stats_shape = np.broadcast(
            np.asarray(coord_std), np.asarray(coord_mean)
        ).shape
        return np.empty((0,) + stats_shape[-1:], dtype=float)

    batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        # Only tensors can be moved between devices; anything else is passed
        # through unchanged.
        features = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        embeddings = encoder(features)["sentence_embedding"]
        batches.append(head(embeddings).cpu().numpy())

    normalized = np.vstack(batches)
    return normalized * coord_std + coord_mean
def make_prediction_plot(results, plot_file):
    """Save a map-style plot linking each actual location to its prediction.

    Args:
        results: DataFrame with ``latitude``, ``longitude``,
            ``predicted_latitude`` and ``predicted_longitude`` columns.
        plot_file: Destination image path. Its parent directory is created
            when the path has a directory component.
    """
    # os.path.dirname returns "" for a bare filename and os.makedirs("")
    # raises, so only create the directory when there is one to create.
    plot_dir = os.path.dirname(plot_file)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Faint connector from every actual point to its predicted point.
        connectors = zip(
            results["longitude"],
            results["predicted_longitude"],
            results["latitude"],
            results["predicted_latitude"],
        )
        for lon, predicted_lon, lat, predicted_lat in connectors:
            ax.plot(
                [lon, predicted_lon],
                [lat, predicted_lat],
                color="0.75",
                linewidth=0.4,
                alpha=0.35,
            )
        ax.scatter(
            results["longitude"],
            results["latitude"],
            s=18,
            color="#1f77b4",
            alpha=0.75,
            label="actual",
        )
        ax.scatter(
            results["predicted_longitude"],
            results["predicted_latitude"],
            s=12,
            color="#d62728",
            alpha=0.45,
            label="predicted",
        )
        ax.set_xlabel("longitude")
        ax.set_ylabel("latitude")
        ax.set_title("Sign Text Coordinate Predictions")
        ax.legend()
        ax.set_aspect("equal", adjustable="box")
        fig.tight_layout()
        fig.savefig(plot_file, dpi=200)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
def make_predicted_vs_actual_plot(results, plot_file):
    """Save a 2x1 scatter figure of predicted vs. actual latitude and longitude.

    The "after training" predictions are always drawn. The "before training"
    predictions are drawn only when the ``initial_predicted_*`` columns are
    present in ``results``.

    Args:
        results: DataFrame with ``latitude``, ``longitude``,
            ``predicted_latitude`` and ``predicted_longitude`` columns, and
            optionally ``initial_predicted_latitude`` /
            ``initial_predicted_longitude``.
        plot_file: Destination image path. Its parent directory is created
            when the path has a directory component.
    """
    # os.path.dirname returns "" for a bare filename and os.makedirs("")
    # raises, so only create the directory when there is one to create.
    plot_dir = os.path.dirname(plot_file)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(7, 9))
    plot_specs = [
        ("latitude", "initial_predicted_latitude", "predicted_latitude", "Latitude"),
        (
            "longitude",
            "initial_predicted_longitude",
            "predicted_longitude",
            "Longitude",
        ),
    ]
    try:
        for ax, (actual_col, initial_col, predicted_col, title) in zip(axes, plot_specs):
            actual = results[actual_col]
            initial = results[initial_col] if initial_col in results else None
            predicted = results[predicted_col]

            # Use one range for both axes so the identity line is the diagonal.
            values = [actual, predicted]
            if initial is not None:
                values.append(initial)
            lower = min(series.min() for series in values)
            upper = max(series.max() for series in values)

            if initial is not None:
                ax.scatter(
                    actual,
                    initial,
                    s=14,
                    alpha=0.45,
                    color="#d62728",
                    label="before training",
                )
            ax.scatter(
                actual,
                predicted,
                s=14,
                alpha=0.55,
                color="#1f77b4",
                label="after training",
            )
            # Reference line where predicted equals actual.
            ax.plot([lower, upper], [lower, upper], color="0.25", linewidth=1)
            ax.set_xlabel(f"actual {actual_col}")
            ax.set_ylabel(f"predicted {actual_col}")
            ax.set_title(title)
            ax.set_xlim(lower, upper)
            ax.set_ylim(lower, upper)
            ax.legend()
        fig.tight_layout()
        fig.savefig(plot_file, dpi=200)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
def main():
    """Evaluate the trained model on a CSV and write predictions, plots and metrics.

    Rows missing ``text``, ``latitude`` or ``longitude`` are dropped. The
    trained encoder and head predict coordinates for the remaining rows. When
    an initial head was loaded, the baseline predictions are added as
    ``initial_predicted_*`` columns. Per-row ``haversine_km`` errors are
    stored in ``error_km``. The results are written to the output CSV, both
    plots are saved, and the mean, median and 90th-percentile errors are
    printed.

    Raises:
        SystemExit: If the data file has no row with all required fields.
    """
    args = parse_args()
    device = get_device(args.device)
    (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    ) = load_model(args.model_path, device)

    data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
    if data.empty:
        # Without rows the prediction, metric and plotting steps fail deep
        # inside numpy or matplotlib; stop here with an actionable message.
        raise SystemExit(
            f"No rows with text, latitude and longitude found in {args.data_file}"
        )
    texts = data["text"].astype(str).tolist()

    predicted = predict(
        trained_encoder,
        trained_head,
        texts,
        coord_mean,
        coord_std,
        device,
        args.batch_size,
    )
    true_coords = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    errors = haversine_km(predicted, true_coords)

    results = data.copy()
    results["predicted_latitude"] = predicted[:, 0]
    results["predicted_longitude"] = predicted[:, 1]
    if initial_head is not None:
        initial_predicted = predict(
            base_encoder,
            initial_head,
            texts,
            coord_mean,
            coord_std,
            device,
            args.batch_size,
        )
        results["initial_predicted_latitude"] = initial_predicted[:, 0]
        results["initial_predicted_longitude"] = initial_predicted[:, 1]
    results["error_km"] = errors

    # The plot helpers create their output directories; do the same for the
    # CSV. A bare filename has no directory component, so skip it then.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    results.to_csv(args.output_file, index=False)

    make_prediction_plot(results, args.plot_file)
    make_predicted_vs_actual_plot(results, args.scatter_plot_file)

    print(f"Wrote predictions to {args.output_file}")
    print(f"Wrote plot to {args.plot_file}")
    print(f"Wrote scatter plot to {args.scatter_plot_file}")
    print(f"mean_error_km={errors.mean():.3f}")
    print(f"median_error_km={np.median(errors):.3f}")
    print(f"p90_error_km={np.percentile(errors, 90):.3f}")


if __name__ == "__main__":
    main()