diff --git a/.gitignore b/.gitignore index 8184cf2..033ef8b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ plots* *.csv output/ .requirements_installed +__pycache__/ diff --git a/eval.py b/eval.py index a1edd2c..932e547 100644 --- a/eval.py +++ b/eval.py @@ -38,8 +38,6 @@ def get_device(requested_device): return torch.device(requested_device) if torch.cuda.is_available(): return torch.device("cuda") - if torch.backends.mps.is_available(): - return torch.device("mps") return torch.device("cpu") diff --git a/train.py b/train.py index 6e9f0bc..27d730c 100644 --- a/train.py +++ b/train.py @@ -74,8 +74,6 @@ def get_device(requested_device): return torch.device(requested_device) if torch.cuda.is_available(): return torch.device("cuda") - if torch.backends.mps.is_available(): - return torch.device("mps") return torch.device("cpu")