|
@ -55,13 +55,13 @@ def create_circle( |
|
|
|
|
|
|
|
|
xkcd_colors, _ = extract_colors() |
|
|
xkcd_colors, _ = extract_colors() |
|
|
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
preds = M(xkcd_colors) |
|
|
preds = M(xkcd_colors).detach().cpu().numpy() |
|
|
rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds( |
|
|
def plot_preds( |
|
|
preds: torch.Tensor | np.ndarray, |
|
|
preds: np.ndarray, |
|
|
rgb_values, |
|
|
rgb_values, |
|
|
fname: str, |
|
|
fname: str, |
|
|
roll: bool = False, |
|
|
roll: bool = False, |
|
@ -71,8 +71,6 @@ def plot_preds( |
|
|
fsize: int = 0, |
|
|
fsize: int = 0, |
|
|
label: str = "", |
|
|
label: str = "", |
|
|
): |
|
|
): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
|
|
|
preds = preds.detach().cpu().numpy() |
|
|
|
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
colors = rgb_values[sorted_inds, :3] |
|
|
colors = rgb_values[sorted_inds, :3] |
|
|
if roll: |
|
|
if roll: |
|
|