device fix (cuda)
This commit is contained in:
parent
40967a303c
commit
b72cd1b917
21
eval.py
21
eval.py
@ -66,6 +66,7 @@ def load_model(model_path, device):
|
||||
config = json.load(f)
|
||||
|
||||
trained_encoder = SentenceTransformer(model_path, device=str(device))
|
||||
trained_encoder.to(device)
|
||||
trained_head = load_head(
|
||||
model_path, "coordinate_head.pt", trained_encoder, config, device
|
||||
)
|
||||
@ -76,6 +77,7 @@ def load_model(model_path, device):
|
||||
base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
|
||||
else:
|
||||
base_encoder = SentenceTransformer(config["model_name"], device=str(device))
|
||||
base_encoder.to(device)
|
||||
initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
|
||||
initial_head = None
|
||||
if os.path.exists(initial_head_path):
|
||||
@ -120,16 +122,9 @@ def make_prediction_plot(results, plot_file):
|
||||
color="0.75",
|
||||
linewidth=0.4,
|
||||
alpha=0.35,
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
ax.scatter(
|
||||
results["longitude"],
|
||||
results["latitude"],
|
||||
s=18,
|
||||
color="#1f77b4",
|
||||
alpha=0.75,
|
||||
label="actual",
|
||||
)
|
||||
ax.scatter(
|
||||
results["predicted_longitude"],
|
||||
results["predicted_latitude"],
|
||||
@ -137,6 +132,16 @@ def make_prediction_plot(results, plot_file):
|
||||
color="#d62728",
|
||||
alpha=0.45,
|
||||
label="predicted",
|
||||
zorder=2,
|
||||
)
|
||||
ax.scatter(
|
||||
results["longitude"],
|
||||
results["latitude"],
|
||||
s=18,
|
||||
color="#1f77b4",
|
||||
alpha=0.75,
|
||||
label="actual",
|
||||
zorder=3,
|
||||
)
|
||||
ax.set_xlabel("longitude")
|
||||
ax.set_ylabel("latitude")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user