|
@ -1,8 +1,8 @@ |
|
|
# import matplotlib.patches as patches |
|
|
# import matplotlib.patches as patches |
|
|
|
|
|
import matplotlib.patches as patches |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
|
import matplotlib.patches as patches |
|
|
|
|
|
|
|
|
|
|
|
from dataloader import extract_colors, preprocess_data |
|
|
from dataloader import extract_colors, preprocess_data |
|
|
from model import ColorTransformerModel |
|
|
from model import ColorTransformerModel |
|
@ -71,22 +71,35 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3 |
|
|
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) |
|
|
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) |
|
|
|
|
|
|
|
|
# Each wedge in the circle |
|
|
# Each wedge in the circle |
|
|
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 |
|
|
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2 |
|
|
width = 2 * np.pi / (N) # equal size for each wedge |
|
|
width = 2 * np.pi / (N) # equal size for each wedge |
|
|
|
|
|
|
|
|
for i in range(N): |
|
|
for i in range(N): |
|
|
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0) |
|
|
ax.bar( |
|
|
|
|
|
theta[i], |
|
|
|
|
|
1, |
|
|
|
|
|
width=width, |
|
|
|
|
|
edgecolor="none", |
|
|
|
|
|
facecolor=colors[i], |
|
|
|
|
|
bottom=0.0, |
|
|
|
|
|
zorder=1, |
|
|
|
|
|
alpha=1, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
ax.set_xticks([]) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
ax.set_yticks([]) |
|
|
ax.axis("off") |
|
|
ax.axis("off") |
|
|
|
|
|
ax.set_aspect("equal") |
|
|
# Overlay white circle |
|
|
# Overlay white circle |
|
|
radius = 1 / 3 |
|
|
radius = 1 / 3 |
|
|
circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2) |
|
|
circle = patches.Circle( |
|
|
|
|
|
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2 |
|
|
|
|
|
) |
|
|
ax.add_patch(circle) |
|
|
ax.add_patch(circle) |
|
|
|
|
|
|
|
|
fig.tight_layout(pad=0) |
|
|
fig.tight_layout(pad=0) |
|
|
plt.savefig(f"{fname}.png", dpi=dpi) |
|
|
|
|
|
|
|
|
plt.savefig(f"{fname}.png", dpi=dpi, transparent=False) |
|
|
plt.close() |
|
|
plt.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|