device fix (cuda)
This commit is contained in:
parent
40967a303c
commit
b72cd1b917
21
eval.py
21
eval.py
@ -66,6 +66,7 @@ def load_model(model_path, device):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
trained_encoder = SentenceTransformer(model_path, device=str(device))
|
trained_encoder = SentenceTransformer(model_path, device=str(device))
|
||||||
|
trained_encoder.to(device)
|
||||||
trained_head = load_head(
|
trained_head = load_head(
|
||||||
model_path, "coordinate_head.pt", trained_encoder, config, device
|
model_path, "coordinate_head.pt", trained_encoder, config, device
|
||||||
)
|
)
|
||||||
@ -76,6 +77,7 @@ def load_model(model_path, device):
|
|||||||
base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
|
base_encoder = SentenceTransformer(initial_encoder_path, device=str(device))
|
||||||
else:
|
else:
|
||||||
base_encoder = SentenceTransformer(config["model_name"], device=str(device))
|
base_encoder = SentenceTransformer(config["model_name"], device=str(device))
|
||||||
|
base_encoder.to(device)
|
||||||
initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
|
initial_head_path = os.path.join(model_path, "initial_coordinate_head.pt")
|
||||||
initial_head = None
|
initial_head = None
|
||||||
if os.path.exists(initial_head_path):
|
if os.path.exists(initial_head_path):
|
||||||
@ -120,16 +122,9 @@ def make_prediction_plot(results, plot_file):
|
|||||||
color="0.75",
|
color="0.75",
|
||||||
linewidth=0.4,
|
linewidth=0.4,
|
||||||
alpha=0.35,
|
alpha=0.35,
|
||||||
|
zorder=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax.scatter(
|
|
||||||
results["longitude"],
|
|
||||||
results["latitude"],
|
|
||||||
s=18,
|
|
||||||
color="#1f77b4",
|
|
||||||
alpha=0.75,
|
|
||||||
label="actual",
|
|
||||||
)
|
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
results["predicted_longitude"],
|
results["predicted_longitude"],
|
||||||
results["predicted_latitude"],
|
results["predicted_latitude"],
|
||||||
@ -137,6 +132,16 @@ def make_prediction_plot(results, plot_file):
|
|||||||
color="#d62728",
|
color="#d62728",
|
||||||
alpha=0.45,
|
alpha=0.45,
|
||||||
label="predicted",
|
label="predicted",
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
ax.scatter(
|
||||||
|
results["longitude"],
|
||||||
|
results["latitude"],
|
||||||
|
s=18,
|
||||||
|
color="#1f77b4",
|
||||||
|
alpha=0.75,
|
||||||
|
label="actual",
|
||||||
|
zorder=3,
|
||||||
)
|
)
|
||||||
ax.set_xlabel("longitude")
|
ax.set_xlabel("longitude")
|
||||||
ax.set_ylabel("latitude")
|
ax.set_ylabel("latitude")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user