|
@ -49,7 +49,7 @@ def create_circle(ckpt: str, fname: str): |
|
|
plot_preds(preds, fname=fname) |
|
|
plot_preds(preds, fname=fname) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(preds, fname: str): |
|
|
def plot_preds(preds, fname: str, roll: bool = False): |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
@ -58,6 +58,7 @@ def plot_preds(preds, fname: str): |
|
|
preds = preds.detach().numpy() |
|
|
preds = preds.detach().numpy() |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
colors = rgb_values[sorted_inds] |
|
|
colors = rgb_values[sorted_inds] |
|
|
|
|
|
if roll: |
|
|
# find white in colors, put it first. |
|
|
# find white in colors, put it first. |
|
|
white = np.array([1, 1, 1]) |
|
|
white = np.array([1, 1, 1]) |
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|