From b72cd1b9170fe625ae4a15fcefb8a4d301bdb5e9 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 25 May 2026 22:10:06 +0000 Subject: [PATCH] device fix (cuda) --- eval.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/eval.py b/eval.py index 3ebbe60..9d57c82 100644 --- a/eval.py +++ b/eval.py @@ -66,6 +66,7 @@ def load_model(model_path, device): config = json.load(f) trained_encoder = SentenceTransformer(model_path, device=str(device)) + trained_encoder.to(device) trained_head = load_head( model_path, "coordinate_head.pt", trained_encoder, config, device ) @@ -76,6 +77,7 @@ def load_model(model_path, device): base_encoder = SentenceTransformer(initial_encoder_path, device=str(device)) else: base_encoder = SentenceTransformer(config["model_name"], device=str(device)) + base_encoder.to(device) initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt") initial_head = None if os.path.exists(initial_head_path): @@ -120,16 +122,9 @@ def make_prediction_plot(results, plot_file): color="0.75", linewidth=0.4, alpha=0.35, + zorder=1, ) - ax.scatter( - results["longitude"], - results["latitude"], - s=18, - color="#1f77b4", - alpha=0.75, - label="actual", - ) ax.scatter( results["predicted_longitude"], results["predicted_latitude"], @@ -137,6 +132,16 @@ def make_prediction_plot(results, plot_file): color="#d62728", alpha=0.45, label="predicted", + zorder=2, + ) + ax.scatter( + results["longitude"], + results["latitude"], + s=18, + color="#1f77b4", + alpha=0.75, + label="actual", + zorder=3, ) ax.set_xlabel("longitude") ax.set_ylabel("latitude")