cuda
This commit is contained in:
parent
b5810dd282
commit
44c7753856
@ -69,7 +69,7 @@ make
|
||||
- `--model-name`: sentence-transformers base model.
|
||||
- `--epochs`: training epochs.
|
||||
- `--batch-size`: batch size.
|
||||
- `--device`: explicit device such as `cpu`, `cuda`, or `mps`.
|
||||
- `--device`: device override such as `cpu`, `cuda`, or `mps`. Defaults to `cuda`.
|
||||
|
||||
## Outputs
|
||||
|
||||
|
||||
6
eval.py
6
eval.py
@ -28,7 +28,11 @@ def parse_args():
|
||||
default="plots/predicted_vs_actual.png",
|
||||
help="2x1 scatter plot of predicted vs actual latitude and longitude.",
|
||||
)
|
||||
parser.add_argument("--device", default=None)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="cuda",
|
||||
help="Device to use for evaluation. Defaults to `cuda`.",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
7
train.py
7
train.py
@ -68,7 +68,11 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
parser.add_argument("--device", default=None)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="cuda",
|
||||
help="Device to use for training. Defaults to `cuda`.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1992)
|
||||
parser.add_argument("--epochs", type=int, default=10)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
@ -341,6 +345,7 @@ def main():
|
||||
)
|
||||
|
||||
encoder = SentenceTransformer(args.model_name, device=str(device))
|
||||
encoder.to(device)
|
||||
embedding_dim = encoder.get_sentence_embedding_dimension()
|
||||
head = CoordinateRegressor(
|
||||
embedding_dim=embedding_dim,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user