|
@ -2,6 +2,7 @@ |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
import matplotlib.patches as patches |
|
|
|
|
|
|
|
|
from dataloader import extract_colors, preprocess_data |
|
|
from dataloader import extract_colors, preprocess_data |
|
|
from model import ColorTransformerModel |
|
|
from model import ColorTransformerModel |
|
@ -49,7 +50,7 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150): |
|
|
plot_preds(preds, fname=fname, dpi=dpi) |
|
|
plot_preds(preds, fname=fname, dpi=dpi) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): |
|
|
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)): |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
@ -67,7 +68,7 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): |
|
|
|
|
|
|
|
|
N = len(colors) |
|
|
N = len(colors) |
|
|
# Create a plot with these hues in a circle |
|
|
# Create a plot with these hues in a circle |
|
|
fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True)) |
|
|
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) |
|
|
|
|
|
|
|
|
# Each wedge in the circle |
|
|
# Each wedge in the circle |
|
|
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 |
|
|
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 |
|
@ -79,7 +80,12 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): |
|
|
ax.set_xticks([]) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
ax.set_yticks([]) |
|
|
ax.axis("off") |
|
|
ax.axis("off") |
|
|
fig.tight_layout() |
|
|
# Overlay white circle |
|
|
|
|
|
radius = 1 / 3 |
|
|
|
|
|
circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2) |
|
|
|
|
|
ax.add_patch(circle) |
|
|
|
|
|
|
|
|
|
|
|
fig.tight_layout(pad=0) |
|
|
plt.savefig(f"{fname}.png", dpi=dpi) |
|
|
plt.savefig(f"{fname}.png", dpi=dpi) |
|
|
plt.close() |
|
|
plt.close() |
|
|
|
|
|
|
|
|