Browse Source

plotting improvements

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
72a1ad2971
  1. 23
      check.py
  2. BIN
      hsv.png
  3. 2
      hsv.py
  4. 2
      search.py

23
check.py

@ -1,8 +1,8 @@
# import matplotlib.patches as patches
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import matplotlib.patches as patches
from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel
@ -71,22 +71,35 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
# Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
width = 2 * np.pi / (N) # equal size for each wedge
for i in range(N):
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0)
ax.bar(
theta[i],
1,
width=width,
edgecolor="none",
facecolor=colors[i],
bottom=0.0,
zorder=1,
alpha=1,
)
ax.set_xticks([])
ax.set_yticks([])
ax.axis("off")
ax.set_aspect("equal")
# Overlay white circle
radius = 1 / 3
circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2)
circle = patches.Circle(
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)
fig.tight_layout(pad=0)
plt.savefig(f"{fname}.png", dpi=dpi)
plt.savefig(f"{fname}.png", dpi=dpi, transparent=False)
plt.close()

BIN
hsv.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 MiB

After

Width:  |  Height:  |  Size: 1.1 MiB

2
hsv.py

@ -8,7 +8,7 @@ if __name__ == "__main__":
rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy()
xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=300, figsize=(6,6))
plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=150, figsize=(6, 6))
rgb = np.eye(3)
print("Pure RGB in Hue-Space:")
print(rgb_to_hsv(rgb)[:, 0])

2
search.py

@ -2,7 +2,7 @@ import subprocess
import sys
from random import sample
import numpy as np
import numpy as np # noqa: F401
from lightning_sdk import Machine, Studio # noqa: F401
NUM_JOBS = 100

Loading…
Cancel
Save