Browse Source

update contract

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
ba9d6c034e
  1. 6
      check.py

6
check.py

@ -55,13 +55,13 @@ def create_circle(
xkcd_colors, _ = extract_colors()
xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors)
preds = M(xkcd_colors).detach().cpu().numpy()
rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds(
preds: torch.Tensor | np.ndarray,
preds: np.ndarray,
rgb_values,
fname: str,
roll: bool = False,
@ -71,8 +71,6 @@ def plot_preds(
fsize: int = 0,
label: str = "",
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3]
if roll:

Loading…
Cancel
Save