device fix (cuda)

This commit is contained in:
Michael Pilosov 2026-05-25 22:10:06 +00:00
parent 40967a303c
commit b72cd1b917

21
eval.py
View File

@ -66,6 +66,7 @@ def load_model(model_path, device):
config = json.load(f)
trained_encoder = SentenceTransformer(model_path, device=str(device))
trained_encoder.to(device)
trained_head = load_head(
model_path, "coordinate_head.pt", trained_encoder, config, device
)
@ -76,6 +77,7 @@ def load_model(model_path, device):
base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
else:
base_encoder = SentenceTransformer(config["model_name"], device=str(device))
base_encoder.to(device)
initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
initial_head = None
if os.path.exists(initial_head_path):
@ -120,16 +122,9 @@ def make_prediction_plot(results, plot_file):
color="0.75",
linewidth=0.4,
alpha=0.35,
zorder=1,
)
ax.scatter(
results["longitude"],
results["latitude"],
s=18,
color="#1f77b4",
alpha=0.75,
label="actual",
)
ax.scatter(
results["predicted_longitude"],
results["predicted_latitude"],
@ -137,6 +132,16 @@ def make_prediction_plot(results, plot_file):
color="#d62728",
alpha=0.45,
label="predicted",
zorder=2,
)
ax.scatter(
results["longitude"],
results["latitude"],
s=18,
color="#1f77b4",
alpha=0.75,
label="actual",
zorder=3,
)
ax.set_xlabel("longitude")
ax.set_ylabel("latitude")