Browse Source

gpu updates

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
7461406cf2
  1. 4
      check.py
  2. 2
      losses.py
  3. 6
      search.py

4
check.py

@ -45,7 +45,7 @@ def create_circle(ckpt: str, fname: str):
M = ckpt
rgb_tensor, _ = extract_colors()
preds = M(rgb_tensor)
preds = M(rgb_tensor.to(M.device))
plot_preds(preds, fname=fname)
@ -55,7 +55,7 @@ def plot_preds(preds, fname: str, roll: bool = False):
rgb_tensor = preprocess_data(rgb_tensor)
if isinstance(preds, torch.Tensor):
preds = preds.detach().numpy()
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds]
if roll:

2
losses.py

@ -54,7 +54,7 @@ def separation_loss(red, green, blue):
def calculate_separation_loss(model):
# Wrapper function to calculate separation loss
outputs = model(PURE_RGB)
outputs = model(PURE_RGB.to(model.device))
red, green, blue = outputs[0], outputs[1], outputs[2]
return separation_loss(red, green, blue)

6
search.py

@ -9,11 +9,11 @@ NUM_JOBS = 100
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
studio = Studio()
# studio = Studio()
# use the jobs plugin
studio.install_plugin("jobs")
job_plugin = studio.installed_plugins["jobs"]
# studio.install_plugin("jobs")
# job_plugin = studio.installed_plugins["jobs"]
# do a sweep over learning rates

Loading…
Cancel
Save