diff --git a/check.py b/check.py index 7a9bbca..f8e7b27 100644 --- a/check.py +++ b/check.py @@ -49,7 +49,7 @@ def create_circle(ckpt: str, fname: str): plot_preds(preds, fname=fname) -def plot_preds(preds, fname: str): +def plot_preds(preds, fname: str, roll: bool = False): rgb_tensor, _ = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor) @@ -58,11 +58,12 @@ def plot_preds(preds, fname: str): preds = preds.detach().numpy() sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds] - # find white in colors, put it first. - white = np.array([1, 1, 1]) - white_idx = np.where((colors == white).all(axis=1))[0][0] - colors = np.roll(colors, -white_idx, axis=0) - # print(white_idx, colors[:2]) + if roll: + # find white in colors, put it first. + white = np.array([1, 1, 1]) + white_idx = np.where((colors == white).all(axis=1))[0][0] + colors = np.roll(colors, -white_idx, axis=0) + # print(white_idx, colors[:2]) N = len(colors) # Create a plot with these hues in a circle diff --git a/hsv.png b/hsv.png index 839bcc1..ce5df9b 100644 Binary files a/hsv.png and b/hsv.png differ