From 7461406cf25fcda654924fcac150fef67f3cc6ba Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 15 Jan 2024 07:00:50 +0000 Subject: [PATCH] gpu updates --- check.py | 4 ++-- losses.py | 2 +- search.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/check.py b/check.py index f8e7b27..cc715ac 100644 --- a/check.py +++ b/check.py @@ -45,7 +45,7 @@ def create_circle(ckpt: str, fname: str): M = ckpt rgb_tensor, _ = extract_colors() - preds = M(rgb_tensor) + preds = M(rgb_tensor.to(M.device)) plot_preds(preds, fname=fname) @@ -55,7 +55,7 @@ def plot_preds(preds, fname: str, roll: bool = False): rgb_tensor = preprocess_data(rgb_tensor) if isinstance(preds, torch.Tensor): - preds = preds.detach().numpy() + preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds] if roll: diff --git a/losses.py b/losses.py index 2587538..30dfde4 100644 --- a/losses.py +++ b/losses.py @@ -54,7 +54,7 @@ def separation_loss(red, green, blue): def calculate_separation_loss(model): # Wrapper function to calculate separation loss - outputs = model(PURE_RGB) + outputs = model(PURE_RGB.to(model.device)) red, green, blue = outputs[0], outputs[1], outputs[2] return separation_loss(red, green, blue) diff --git a/search.py b/search.py index 3936683..81ad7a9 100644 --- a/search.py +++ b/search.py @@ -9,11 +9,11 @@ NUM_JOBS = 100 # reference to the current studio # if you run outside of Lightning, you can pass the Studio name -studio = Studio() +# studio = Studio() # use the jobs plugin -studio.install_plugin("jobs") -job_plugin = studio.installed_plugins["jobs"] +# studio.install_plugin("jobs") +# job_plugin = studio.installed_plugins["jobs"] # do a sweep over learning rates