|
@ -46,24 +46,27 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150): |
|
|
M = ckpt |
|
|
M = ckpt |
|
|
|
|
|
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
plot_preds(preds, fname=fname, dpi=dpi) |
|
|
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, dpi=dpi) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)): |
|
|
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds( |
|
|
|
|
|
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3) |
|
|
|
|
|
): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
if isinstance(preds, torch.Tensor): |
|
|
preds = preds.detach().cpu().numpy() |
|
|
preds = preds.detach().cpu().numpy() |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
colors = rgb_values[sorted_inds] |
|
|
colors = rgb_values[sorted_inds, :] |
|
|
if roll: |
|
|
if roll: |
|
|
# find white in colors, put it first. |
|
|
# find white in colors, put it first. |
|
|
white = np.array([1, 1, 1]) |
|
|
white = np.array([1, 1, 1]) |
|
|
white_idx = np.where((colors == white).all(axis=1))[0][0] |
|
|
white_idx = np.where((colors == white).all(axis=1)) |
|
|
colors = np.roll(colors, -white_idx, axis=0) |
|
|
if white_idx: |
|
|
|
|
|
white_idx = white_idx[0][0] |
|
|
|
|
|
colors = np.roll(colors, -white_idx, axis=0) |
|
|
|
|
|
else: |
|
|
|
|
|
print("no white, skipping") |
|
|
# print(white_idx, colors[:2]) |
|
|
# print(white_idx, colors[:2]) |
|
|
|
|
|
|
|
|
N = len(colors) |
|
|
N = len(colors) |
|
@ -76,10 +79,13 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3 |
|
|
|
|
|
|
|
|
for i in range(N): |
|
|
for i in range(N): |
|
|
ax.bar( |
|
|
ax.bar( |
|
|
|
|
|
# 2 * np.pi * preds[i], |
|
|
theta[i], |
|
|
theta[i], |
|
|
1, |
|
|
height=1, |
|
|
width=width, |
|
|
width=width, |
|
|
edgecolor="none", |
|
|
edgecolor="none", |
|
|
|
|
|
# facecolor=[rgb_values[i][1]]*3, |
|
|
|
|
|
# facecolor=rgb_values[i], |
|
|
facecolor=colors[i], |
|
|
facecolor=colors[i], |
|
|
bottom=0.0, |
|
|
bottom=0.0, |
|
|
zorder=1, |
|
|
zorder=1, |
|
|