|
|
@ -49,7 +49,7 @@ def create_circle(ckpt: str, fname: str): |
|
|
|
plot_preds(preds, fname=fname) |
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(preds, fname: str): |
|
|
|
def plot_preds(preds, fname: str, roll: bool = False): |
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
@ -58,11 +58,12 @@ def plot_preds(preds, fname: str): |
|
|
|
preds = preds.detach().numpy() |
|
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
|
colors = rgb_values[sorted_inds] |
|
|
|
# find white in colors, put it first. |
|
|
|
white = np.array([1, 1, 1]) |
|
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|
|
colors = np.roll(colors, -white_idx, axis=0) |
|
|
|
# print(white_idx, colors[:2]) |
|
|
|
if roll: |
|
|
|
# find white in colors, put it first. |
|
|
|
white = np.array([1, 1, 1]) |
|
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|
|
colors = np.roll(colors, -white_idx, axis=0) |
|
|
|
# print(white_idx, colors[:2]) |
|
|
|
|
|
|
|
N = len(colors) |
|
|
|
# Create a plot with these hues in a circle |
|
|
|