@ -219,6 +219,8 @@ def plot_preds(
dpi: int = 300,
figsize=(6, 6),
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3]
if roll: