Browse Source

deterministic figs

new-sep-loss
Michael Pilosov 11 months ago
parent
commit
c9a5d0062c
  1. 1
      .gitignore
  2. 6
      check.py

1
.gitignore

@ -3,3 +3,4 @@ __pycache__/
out/ out/
*.png *.png
.sw[opqr] .sw[opqr]
*.tar.gz

6
check.py

@ -8,8 +8,10 @@ from model import ColorTransformerModel
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# import matplotlib.colors as mcolors # import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True): def make_image(ckpt: str, fname: str, color=True):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
@ -41,6 +43,7 @@ def make_image(ckpt: str, fname: str, color=True):
def create_circle(ckpt: str, fname: str): def create_circle(ckpt: str, fname: str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
M.eval()
rgb_tensor, names = extract_colors() rgb_tensor, names = extract_colors()
rgb_values = rgb_tensor.detach().numpy() rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor)
@ -58,7 +61,7 @@ def create_circle(ckpt: str, fname: str):
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
# Each wedge in the circle # Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N+1) + np.pi / 2 theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2
width = 2 * np.pi / (N) # equal size for each wedge width = 2 * np.pi / (N) # equal size for each wedge
for i in range(N): for i in range(N):
@ -69,6 +72,7 @@ def create_circle(ckpt: str, fname: str):
ax.axis("off") ax.axis("off")
fig.tight_layout() fig.tight_layout()
plt.savefig(f"{fname}.png", dpi=300) plt.savefig(f"{fname}.png", dpi=300)
plt.close()
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save