diff --git a/check.py b/check.py index cefe554..c15f542 100644 --- a/check.py +++ b/check.py @@ -55,13 +55,13 @@ def create_circle( xkcd_colors, _ = extract_colors() xkcd_colors = preprocess_data(xkcd_colors).to(M.device) - preds = M(xkcd_colors) + preds = M(xkcd_colors).detach().cpu().numpy() rgb_array = xkcd_colors.detach().cpu().numpy() plot_preds(preds, rgb_array, fname=fname, **kwargs) def plot_preds( - preds: torch.Tensor | np.ndarray, + preds: np.ndarray, rgb_values, fname: str, roll: bool = False, @@ -71,8 +71,6 @@ def plot_preds( fsize: int = 0, label: str = "", ): - if isinstance(preds, torch.Tensor): - preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds, :3] if roll: