From ba9d6c034eb073dade0d3f6a9fb317c16b49b50a Mon Sep 17 00:00:00 2001 From: "Michael Pilosov, PhD" Date: Sun, 3 Mar 2024 00:57:20 +0000 Subject: [PATCH] update contract --- check.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/check.py b/check.py index cefe554..c15f542 100644 --- a/check.py +++ b/check.py @@ -55,13 +55,13 @@ def create_circle( xkcd_colors, _ = extract_colors() xkcd_colors = preprocess_data(xkcd_colors).to(M.device) - preds = M(xkcd_colors) + preds = M(xkcd_colors).detach().cpu().numpy() rgb_array = xkcd_colors.detach().cpu().numpy() plot_preds(preds, rgb_array, fname=fname, **kwargs) def plot_preds( - preds: torch.Tensor | np.ndarray, + preds: np.ndarray, rgb_values, fname: str, roll: bool = False, @@ -71,8 +71,6 @@ def plot_preds( fsize: int = 0, label: str = "", ): - if isinstance(preds, torch.Tensor): - preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds, :3] if roll: