diff --git a/check.py b/check.py index c3c452d..43145a4 100644 --- a/check.py +++ b/check.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +import matplotlib.patches as patches from dataloader import extract_colors, preprocess_data from model import ColorTransformerModel @@ -49,7 +50,7 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150): plot_preds(preds, fname=fname, dpi=dpi) -def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): +def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)): rgb_tensor, _ = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor) @@ -67,7 +68,7 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): N = len(colors) # Create a plot with these hues in a circle - fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True)) + fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) # Each wedge in the circle theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 @@ -79,7 +80,12 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): ax.set_xticks([]) ax.set_yticks([]) ax.axis("off") - fig.tight_layout() + # Overlay white circle + radius = 1 / 3 + circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2) + ax.add_patch(circle) + + fig.tight_layout(pad=0) plt.savefig(f"{fname}.png", dpi=dpi) plt.close() diff --git a/hsv.png b/hsv.png index 839bcc1..18a5fe5 100644 Binary files a/hsv.png and b/hsv.png differ diff --git a/hsv.py b/hsv.py index 5ef389c..7f80ffe 100644 --- a/hsv.py +++ b/hsv.py @@ -8,7 +8,7 @@ if __name__ == "__main__": rgb_tensor, _ = extract_colors() xkcd_rgb = rgb_tensor.numpy() xkcd_hsv = rgb_to_hsv(xkcd_rgb) - plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True) + plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=300, figsize=(6,6)) rgb = np.eye(3) print("Pure RGB in Hue-Space:") print(rgb_to_hsv(rgb)[:, 0])