diff --git a/eval.py b/eval.py index 3ebbe60..9d57c82 100644 --- a/eval.py +++ b/eval.py @@ -66,6 +66,7 @@ def load_model(model_path, device): config = json.load(f) trained_encoder = SentenceTransformer(model_path, device=str(device)) + trained_encoder.to(device) trained_head = load_head( model_path, "coordinate_head.pt", trained_encoder, config, device ) @@ -76,6 +77,7 @@ def load_model(model_path, device): base_encoder = SentenceTransformer(initial_encoder_path, device=str(device)) else: base_encoder = SentenceTransformer(config["model_name"], device=str(device)) + base_encoder.to(device) initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt") initial_head = None if os.path.exists(initial_head_path): @@ -120,16 +122,9 @@ def make_prediction_plot(results, plot_file): color="0.75", linewidth=0.4, alpha=0.35, + zorder=1, ) - ax.scatter( - results["longitude"], - results["latitude"], - s=18, - color="#1f77b4", - alpha=0.75, - label="actual", - ) ax.scatter( results["predicted_longitude"], results["predicted_latitude"], @@ -137,6 +132,16 @@ def make_prediction_plot(results, plot_file): color="#d62728", alpha=0.45, label="predicted", + zorder=2, + ) + ax.scatter( + results["longitude"], + results["latitude"], + s=18, + color="#1f77b4", + alpha=0.75, + label="actual", + zorder=3, ) ax.set_xlabel("longitude") ax.set_ylabel("latitude")