From 44c775385603e1b24b07419eedbc64ef28f2fde3 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 25 May 2026 21:25:24 +0000 Subject: [PATCH] cuda --- README.md | 2 +- eval.py | 6 +++++- train.py | 7 ++++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index abc887f..c94c6ce 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ make - `--model-name`: sentence-transformers base model. - `--epochs`: training epochs. - `--batch-size`: batch size. -- `--device`: explicit device such as `cpu`, `cuda`, or `mps`. +- `--device`: device override such as `cpu`, `cuda`, or `mps`. Defaults to `cuda`. ## Outputs diff --git a/eval.py b/eval.py index 932e547..3ebbe60 100644 --- a/eval.py +++ b/eval.py @@ -28,7 +28,11 @@ def parse_args(): default="plots/predicted_vs_actual.png", help="2x1 scatter plot of predicted vs actual latitude and longitude.", ) - parser.add_argument("--device", default=None) + parser.add_argument( + "--device", + default="cuda", + help="Device to use for evaluation. Defaults to `cuda`.", + ) parser.add_argument("--batch-size", type=int, default=64) return parser.parse_args() diff --git a/train.py b/train.py index e56689f..37be00e 100644 --- a/train.py +++ b/train.py @@ -68,7 +68,11 @@ def parse_args(): parser.add_argument( "--model-name", default="sentence-transformers/all-MiniLM-L6-v2" ) - parser.add_argument("--device", default=None) + parser.add_argument( + "--device", + default="cuda", + help="Device to use for training. Defaults to `cuda`.", + ) parser.add_argument("--seed", type=int, default=1992) parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--batch-size", type=int, default=32) @@ -341,6 +345,7 @@ def main(): ) encoder = SentenceTransformer(args.model_name, device=str(device)) + encoder.to(device) embedding_dim = encoder.get_sentence_embedding_dimension() head = CoordinateRegressor( embedding_dim=embedding_dim,