|
|
@ -45,7 +45,7 @@ def create_circle(ckpt: str, fname: str): |
|
|
|
M = ckpt |
|
|
|
|
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
preds = M(rgb_tensor) |
|
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
|
plot_preds(preds, fname=fname) |
|
|
|
|
|
|
|
|
|
|
@ -55,7 +55,7 @@ def plot_preds(preds, fname: str, roll: bool = False): |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
|
|
|
|
if isinstance(preds, torch.Tensor): |
|
|
|
preds = preds.detach().numpy() |
|
|
|
preds = preds.detach().cpu().numpy() |
|
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
|
colors = rgb_values[sorted_inds] |
|
|
|
if roll: |
|
|
|