diff --git a/.gitignore b/.gitignore index 0facd2b..651f70e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__/ out/ *.png .sw[opqr] +*.tar.gz diff --git a/check.py b/check.py index 1fb976d..7f13648 100644 --- a/check.py +++ b/check.py @@ -8,8 +8,10 @@ from model import ColorTransformerModel import numpy as np import matplotlib.pyplot as plt + # import matplotlib.colors as mcolors + def make_image(ckpt: str, fname: str, color=True): M = ColorTransformerModel.load_from_checkpoint(ckpt) @@ -41,6 +43,7 @@ def make_image(ckpt: str, fname: str, color=True): def create_circle(ckpt: str, fname: str): M = ColorTransformerModel.load_from_checkpoint(ckpt) + M.eval() rgb_tensor, names = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor) @@ -58,7 +61,7 @@ def create_circle(ckpt: str, fname: str): fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) # Each wedge in the circle - theta = np.linspace(0, 2 * np.pi, N+1) + np.pi / 2 + theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 width = 2 * np.pi / (N) # equal size for each wedge for i in range(N): @@ -69,6 +72,7 @@ def create_circle(ckpt: str, fname: str): ax.axis("off") fig.tight_layout() plt.savefig(f"{fname}.png", dpi=300) + plt.close() if __name__ == "__main__":