diff --git a/check.py b/check.py index 43145a4..afefcda 100644 --- a/check.py +++ b/check.py @@ -1,8 +1,8 @@ # import matplotlib.patches as patches +import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np import torch -import matplotlib.patches as patches from dataloader import extract_colors, preprocess_data from model import ColorTransformerModel @@ -71,22 +71,35 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3 fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) # Each wedge in the circle - theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 + theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2 width = 2 * np.pi / (N) # equal size for each wedge for i in range(N): - ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0) + ax.bar( + theta[i], + 1, + width=width, + edgecolor="none", + facecolor=colors[i], + bottom=0.0, + zorder=1, + alpha=1, + ) ax.set_xticks([]) ax.set_yticks([]) ax.axis("off") + ax.set_aspect("equal") # Overlay white circle radius = 1 / 3 - circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2) + circle = patches.Circle( + (0, 0), radius, transform=ax.transData._b, color="white", zorder=2 + ) ax.add_patch(circle) fig.tight_layout(pad=0) - plt.savefig(f"{fname}.png", dpi=dpi) + + plt.savefig(f"{fname}.png", dpi=dpi, transparent=False) plt.close() diff --git a/hsv.png b/hsv.png index 18a5fe5..a23a00c 100644 Binary files a/hsv.png and b/hsv.png differ diff --git a/hsv.py b/hsv.py index 7f80ffe..d453312 100644 --- a/hsv.py +++ b/hsv.py @@ -8,7 +8,7 @@ if __name__ == "__main__": rgb_tensor, _ = extract_colors() xkcd_rgb = rgb_tensor.numpy() xkcd_hsv = rgb_to_hsv(xkcd_rgb) - plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=300, figsize=(6,6)) + plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=150, figsize=(6, 6)) rgb = np.eye(3) print("Pure RGB in Hue-Space:") print(rgb_to_hsv(rgb)[:, 0]) diff --git a/search.py b/search.py index 05b577b..6b0b152 100644 --- a/search.py +++ b/search.py @@ -2,7 +2,7 @@ import subprocess import sys from random import sample -import numpy as np +import numpy as np # noqa: F401 from lightning_sdk import Machine, Studio # noqa: F401 NUM_JOBS = 100