diff --git a/scripts/sortcolor.py b/scripts/sortcolor.py index 5d2ebfa..6d48c76 100644 --- a/scripts/sortcolor.py +++ b/scripts/sortcolor.py @@ -219,6 +219,8 @@ def plot_preds( dpi: int = 300, figsize=(6, 6), ): + if isinstance(preds, torch.Tensor): + preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds, :3] if roll: