Browse Source

update contract

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
ba9d6c034e
  1. 6
      check.py

6
check.py

@ -55,13 +55,13 @@ def create_circle(
xkcd_colors, _ = extract_colors() xkcd_colors, _ = extract_colors()
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors) preds = M(xkcd_colors).detach().cpu().numpy()
rgb_array = xkcd_colors.detach().cpu().numpy() rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs) plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds( def plot_preds(
preds: torch.Tensor | np.ndarray, preds: np.ndarray,
rgb_values, rgb_values,
fname: str, fname: str,
roll: bool = False, roll: bool = False,
@ -71,8 +71,6 @@ def plot_preds(
fsize: int = 0, fsize: int = 0,
label: str = "", label: str = "",
): ):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel()) sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3] colors = rgb_values[sorted_inds, :3]
if roll: if roll:

Loading…
Cancel
Save