265 lines
8.1 KiB
Python
265 lines
8.1 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
|
|
os.environ.setdefault("MPLCONFIGDIR", "/private/tmp/matplotlib")
|
|
os.environ.setdefault("XDG_CACHE_HOME", "/private/tmp")
|
|
|
|
import matplotlib # noqa: E402
|
|
import numpy as np # noqa: E402
|
|
import pandas as pd # noqa: E402
|
|
import torch # noqa: E402
|
|
from sentence_transformers import SentenceTransformer # noqa: E402
|
|
|
|
from train import CoordinateRegressor, haversine_km # noqa: E402
|
|
|
|
matplotlib.use("Agg")
|
|
from matplotlib import pyplot as plt # noqa: E402
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: with attributes ``data_file``, ``model_path``,
        ``output_file``, ``plot_file``, ``scatter_plot_file``, ``device``
        and ``batch_size``.

    Exits with an argparse usage error when ``--batch-size`` is not a
    positive integer.
    """
    parser = argparse.ArgumentParser(description="Evaluate sign coordinate model.")
    parser.add_argument("--data-file", default="training.csv")
    parser.add_argument("--model-path", default="output")
    parser.add_argument("--output-file", default="predictions.csv")
    parser.add_argument("--plot-file", default="plots/prediction_map.png")
    parser.add_argument(
        "--scatter-plot-file",
        default="plots/predicted_vs_actual.png",
        help="2x1 scatter plot of predicted vs actual latitude and longitude.",
    )
    parser.add_argument(
        "--device",
        # No hard-coded default: a falsy value lets get_device() pick CUDA
        # when it is available and fall back to the CPU otherwise, instead
        # of failing on machines without a GPU.
        default=None,
        help=(
            "Device to use for evaluation. Defaults to `cuda` when "
            "available, otherwise `cpu`."
        ),
    )
    parser.add_argument("--batch-size", type=int, default=64)

    args = parser.parse_args()
    # A zero or negative batch size would only fail later, inside the
    # batching loop, with a much less helpful error.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args
|
|
|
|
|
|
def get_device(requested_device):
    """Resolve the torch device to run on.

    Args:
        requested_device: Device spec accepted by ``torch.device``. Any
            falsy value (``None``, ``""``) means "choose automatically".

    Returns:
        torch.device: the requested device, or ``cuda`` when CUDA is
        available and ``cpu`` otherwise if nothing was requested.
    """
    # `or` short-circuits, so CUDA is only probed when nothing was requested.
    chosen = requested_device or ("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(chosen)
|
|
|
|
|
|
def load_head(model_path, filename, encoder, config, device):
    """Build a ``CoordinateRegressor`` head and restore its saved weights.

    Args:
        model_path: Directory that contains the weights file.
        filename: Name of the state-dict file inside ``model_path``.
        encoder: Object exposing ``get_sentence_embedding_dimension()``;
            the returned value is used as the head's ``embedding_dim``.
        config: Mapping providing the ``"hidden_dim"`` and ``"dropout"``
            values passed to the head's constructor.
        device: Device the head is moved to and the weights are mapped onto.

    Returns:
        The head with the loaded state dict, switched to eval mode.
    """
    regressor = CoordinateRegressor(
        embedding_dim=encoder.get_sentence_embedding_dimension(),
        hidden_dim=config["hidden_dim"],
        dropout=config["dropout"],
    )
    regressor = regressor.to(device)

    weights_path = os.path.join(model_path, filename)
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only point this at trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=device)
    regressor.load_state_dict(state_dict)

    regressor.eval()
    return regressor
|
|
|
|
|
|
def load_model(model_path, device):
    """Load the trained model plus the pre-training baseline, if saved.

    Reads ``coordinate_config.json`` from ``model_path``, loads the trained
    encoder and ``coordinate_head.pt`` from the same directory, and loads a
    baseline encoder from ``<model_path>/initial_encoder`` when that path
    exists (otherwise from ``config["model_name"]``). The baseline head is
    only loaded when ``initial_coordinate_head.pt`` exists.

    Args:
        model_path: Directory holding the config, encoder and head files.
        device: Device all models are placed on.

    Returns:
        tuple: ``(base_encoder, initial_head, trained_encoder, trained_head,
        coord_mean, coord_std)``. ``initial_head`` is ``None`` when no
        initial head file was found; ``coord_mean`` and ``coord_std`` are
        NumPy arrays built from the config.
    """
    config_path = os.path.join(model_path, "coordinate_config.json")
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # text encoding when reading it.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    trained_encoder = SentenceTransformer(model_path, device=str(device))
    trained_encoder.to(device)
    trained_head = load_head(
        model_path, "coordinate_head.pt", trained_encoder, config, device
    )
    trained_encoder.eval()

    # Prefer the encoder snapshot saved with the run; otherwise fall back to
    # the model named in the config.
    initial_encoder_path = os.path.join(model_path, "initial_encoder")
    if os.path.exists(initial_encoder_path):
        base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
    else:
        base_encoder = SentenceTransformer(config["model_name"], device=str(device))
    base_encoder.to(device)

    # The baseline head is optional: without it callers receive None and can
    # skip the "before training" comparison.
    initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
    initial_head = None
    if os.path.exists(initial_head_path):
        initial_head = load_head(
            model_path, "initial_coordinate_head.pt", base_encoder, config, device
        )
    base_encoder.eval()

    return (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        np.array(config["coord_mean"]),
        np.array(config["coord_std"]),
    )
|
|
|
|
|
|
@torch.no_grad()
def predict(encoder, head, texts, coord_mean, coord_std, device, batch_size):
    """Predict de-normalized coordinates for ``texts`` in batches.

    Each batch is tokenized with ``encoder.tokenize``, moved to ``device``,
    encoded, and passed through ``head``. The stacked head outputs are then
    rescaled with ``coord_std`` and shifted by ``coord_mean``.

    Args:
        encoder: Callable exposing ``tokenize(batch)`` that returns a mapping
            of tensors, and that returns a mapping containing
            ``"sentence_embedding"`` when called on those features.
        head: Callable mapping embeddings to normalized predictions.
        texts: Sliceable sequence of input strings.
        coord_mean: Offset added to the normalized predictions.
        coord_std: Scale the normalized predictions are multiplied by.
        device: Device the tokenized features are moved to.
        batch_size: Number of texts per forward pass; must be positive.

    Returns:
        numpy.ndarray: one row per text. For empty ``texts`` an array of
        shape ``(0, np.size(coord_mean))`` is returned.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if len(texts) == 0:
        # np.vstack([]) raises on an empty list; hand back an empty result
        # with one column per coordinate instead.
        return np.empty((0, np.size(coord_mean)), dtype=float)

    predictions = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        features = encoder.tokenize(batch)
        features = {key: value.to(device) for key, value in features.items()}
        embeddings = encoder(features)["sentence_embedding"]
        # Move to the CPU before converting so this also works on GPU tensors.
        predictions.append(head(embeddings).cpu().numpy())

    normalized = np.vstack(predictions)
    return normalized * coord_std + coord_mean
|
|
|
|
|
|
def make_prediction_plot(results, plot_file):
    """Save a map-style plot of actual vs predicted coordinates.

    Draws a thin connector from every actual point to its prediction, then
    the predicted points and the actual points on top, and writes the figure
    to ``plot_file``.

    Args:
        results: DataFrame with ``longitude``, ``latitude``,
            ``predicted_longitude`` and ``predicted_latitude`` columns.
        plot_file: Output image path. Its parent directory is created when
            the path has one.
    """
    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is given.
    plot_dir = os.path.dirname(plot_file)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Iterate the four columns in lockstep; this avoids building a Series
        # per row as DataFrame.iterrows() does.
        connectors = zip(
            results["longitude"],
            results["predicted_longitude"],
            results["latitude"],
            results["predicted_latitude"],
        )
        for lon, predicted_lon, lat, predicted_lat in connectors:
            ax.plot(
                [lon, predicted_lon],
                [lat, predicted_lat],
                color="0.75",
                linewidth=0.4,
                alpha=0.35,
                zorder=1,
            )

        ax.scatter(
            results["predicted_longitude"],
            results["predicted_latitude"],
            s=12,
            color="#d62728",
            alpha=0.45,
            label="predicted",
            zorder=2,
        )
        ax.scatter(
            results["longitude"],
            results["latitude"],
            s=18,
            color="#1f77b4",
            alpha=0.75,
            label="actual",
            zorder=3,
        )
        ax.set_xlabel("longitude")
        ax.set_ylabel("latitude")
        ax.set_title("Sign Text Coordinate Predictions")
        ax.legend()
        ax.set_aspect("equal", adjustable="box")
        fig.tight_layout()
        fig.savefig(plot_file, dpi=200)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def make_predicted_vs_actual_plot(results, plot_file):
    """Save a 2x1 scatter figure of predicted vs actual coordinates.

    The top panel shows latitude and the bottom panel longitude. Each panel
    plots the "after training" predictions against the actual values, adds
    the "before training" predictions when the corresponding
    ``initial_predicted_*`` column is present, and draws the identity line.

    Args:
        results: DataFrame with ``latitude``, ``longitude``,
            ``predicted_latitude`` and ``predicted_longitude`` columns, and
            optionally ``initial_predicted_latitude`` /
            ``initial_predicted_longitude``.
        plot_file: Output image path. Its parent directory is created when
            the path has one.
    """
    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is given.
    plot_dir = os.path.dirname(plot_file)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(7, 9))
    try:
        plot_specs = [
            ("latitude", "initial_predicted_latitude", "predicted_latitude", "Latitude"),
            (
                "longitude",
                "initial_predicted_longitude",
                "predicted_longitude",
                "Longitude",
            ),
        ]
        for ax, (actual_col, initial_col, predicted_col, title) in zip(axes, plot_specs):
            actual = results[actual_col]
            initial = results[initial_col] if initial_col in results else None
            predicted = results[predicted_col]

            # Use one shared range for both axes so the identity line runs
            # corner to corner.
            values = [actual, predicted]
            if initial is not None:
                values.append(initial)
            lower = min(series.min() for series in values)
            upper = max(series.max() for series in values)

            if initial is not None:
                ax.scatter(
                    actual,
                    initial,
                    s=14,
                    alpha=0.45,
                    color="#d62728",
                    label="before training",
                )
            ax.scatter(
                actual,
                predicted,
                s=14,
                alpha=0.55,
                color="#1f77b4",
                label="after training",
            )
            ax.plot([lower, upper], [lower, upper], color="0.25", linewidth=1)
            ax.set_xlabel(f"actual {actual_col}")
            ax.set_ylabel(f"predicted {actual_col}")
            ax.set_title(title)
            ax.set_xlim(lower, upper)
            ax.set_ylim(lower, upper)
            ax.legend()

        fig.tight_layout()
        fig.savefig(plot_file, dpi=200)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def main():
    """Evaluate the trained model on a CSV and write predictions and plots.

    Loads the models, predicts coordinates for every row of the data file
    that has ``text``, ``latitude`` and ``longitude``, writes the results
    (including ``error_km`` and, when an initial head was saved, the
    ``initial_predicted_*`` columns) to the output CSV, renders both plots,
    and prints the mean, median and 90th-percentile error.

    Raises:
        SystemExit: if the data file has no usable rows.
    """
    args = parse_args()
    device = get_device(args.device)
    (
        base_encoder,
        initial_head,
        trained_encoder,
        trained_head,
        coord_mean,
        coord_std,
    ) = load_model(args.model_path, device)

    data = pd.read_csv(args.data_file).dropna(subset=["text", "latitude", "longitude"])
    if data.empty:
        # Without this guard the run fails later, while stacking zero
        # prediction batches, with an error that does not mention the data.
        raise SystemExit(
            f"No rows with text, latitude and longitude found in {args.data_file}"
        )

    texts = data["text"].astype(str).tolist()
    predicted = predict(
        trained_encoder,
        trained_head,
        texts,
        coord_mean,
        coord_std,
        device,
        args.batch_size,
    )
    true_coords = data[["latitude", "longitude"]].to_numpy(dtype=np.float32)
    # NOTE(review): haversine_km comes from train.py; presumably per-row
    # great-circle distance in kilometres -- confirm against its definition.
    errors = haversine_km(predicted, true_coords)

    results = data.copy()
    results["predicted_latitude"] = predicted[:, 0]
    results["predicted_longitude"] = predicted[:, 1]
    if initial_head is not None:
        # Baseline predictions are only available when the initial head was
        # saved next to the trained model.
        initial_predicted = predict(
            base_encoder,
            initial_head,
            texts,
            coord_mean,
            coord_std,
            device,
            args.batch_size,
        )
        results["initial_predicted_latitude"] = initial_predicted[:, 0]
        results["initial_predicted_longitude"] = initial_predicted[:, 1]
    results["error_km"] = errors
    results.to_csv(args.output_file, index=False)
    make_prediction_plot(results, args.plot_file)
    make_predicted_vs_actual_plot(results, args.scatter_plot_file)

    print(f"Wrote predictions to {args.output_file}")
    print(f"Wrote plot to {args.plot_file}")
    print(f"Wrote scatter plot to {args.scatter_plot_file}")
    print(f"mean_error_km={errors.mean():.3f}")
    print(f"median_error_km={np.median(errors):.3f}")
    print(f"p90_error_km={np.percentile(errors, 90):.3f}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
|