This commit is contained in:
Michael Pilosov 2026-05-25 21:25:24 +00:00
parent b5810dd282
commit 44c7753856
3 changed files with 12 additions and 3 deletions

View File

@ -69,7 +69,7 @@ make
- `--model-name`: sentence-transformers base model. - `--model-name`: sentence-transformers base model.
- `--epochs`: training epochs. - `--epochs`: training epochs.
- `--batch-size`: batch size. - `--batch-size`: batch size.
- `--device`: explicit device such as `cpu`, `cuda`, or `mps`. - `--device`: device override such as `cpu`, `cuda`, or `mps`. Defaults to `cuda`.
## Outputs ## Outputs

View File

@ -28,7 +28,11 @@ def parse_args():
default="plots/predicted_vs_actual.png", default="plots/predicted_vs_actual.png",
help="2x1 scatter plot of predicted vs actual latitude and longitude.", help="2x1 scatter plot of predicted vs actual latitude and longitude.",
) )
parser.add_argument("--device", default=None) parser.add_argument(
"--device",
default="cuda",
help="Device to use for evaluation. Defaults to `cuda`.",
)
parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--batch-size", type=int, default=64)
return parser.parse_args() return parser.parse_args()

View File

@ -68,7 +68,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--model-name", default="sentence-transformers/all-MiniLM-L6-v2" "--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
) )
parser.add_argument("--device", default=None) parser.add_argument(
"--device",
default="cuda",
help="Device to use for training. Defaults to `cuda`.",
)
parser.add_argument("--seed", type=int, default=1992) parser.add_argument("--seed", type=int, default=1992)
parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
@ -341,6 +345,7 @@ def main():
) )
encoder = SentenceTransformer(args.model_name, device=str(device)) encoder = SentenceTransformer(args.model_name, device=str(device))
encoder.to(device)
embedding_dim = encoder.get_sentence_embedding_dimension() embedding_dim = encoder.get_sentence_embedding_dimension()
head = CoordinateRegressor( head = CoordinateRegressor(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,