cuda
This commit is contained in:
parent
b5810dd282
commit
44c7753856
@ -69,7 +69,7 @@ make
|
|||||||
- `--model-name`: sentence-transformers base model.
|
- `--model-name`: sentence-transformers base model.
|
||||||
- `--epochs`: training epochs.
|
- `--epochs`: training epochs.
|
||||||
- `--batch-size`: batch size.
|
- `--batch-size`: batch size.
|
||||||
- `--device`: explicit device such as `cpu`, `cuda`, or `mps`.
|
- `--device`: device override such as `cpu`, `cuda`, or `mps`. Defaults to `cuda`.
|
||||||
|
|
||||||
## Outputs
|
## Outputs
|
||||||
|
|
||||||
|
|||||||
6
eval.py
6
eval.py
@ -28,7 +28,11 @@ def parse_args():
|
|||||||
default="plots/predicted_vs_actual.png",
|
default="plots/predicted_vs_actual.png",
|
||||||
help="2x1 scatter plot of predicted vs actual latitude and longitude.",
|
help="2x1 scatter plot of predicted vs actual latitude and longitude.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", default=None)
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda",
|
||||||
|
help="Device to use for evaluation. Defaults to `cuda`.",
|
||||||
|
)
|
||||||
parser.add_argument("--batch-size", type=int, default=64)
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
7
train.py
7
train.py
@ -68,7 +68,11 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
|
"--model-name", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", default=None)
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda",
|
||||||
|
help="Device to use for training. Defaults to `cuda`.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=1992)
|
parser.add_argument("--seed", type=int, default=1992)
|
||||||
parser.add_argument("--epochs", type=int, default=10)
|
parser.add_argument("--epochs", type=int, default=10)
|
||||||
parser.add_argument("--batch-size", type=int, default=32)
|
parser.add_argument("--batch-size", type=int, default=32)
|
||||||
@ -341,6 +345,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
encoder = SentenceTransformer(args.model_name, device=str(device))
|
encoder = SentenceTransformer(args.model_name, device=str(device))
|
||||||
|
encoder.to(device)
|
||||||
embedding_dim = encoder.get_sentence_embedding_dimension()
|
embedding_dim = encoder.get_sentence_embedding_dimension()
|
||||||
head = CoordinateRegressor(
|
head = CoordinateRegressor(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user