Browse Source

plotting improvements

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
72a1ad2971
  1. 23
      check.py
  2. BIN
      hsv.png
  3. 2
      hsv.py
  4. 2
      search.py

23
check.py

@ -1,8 +1,8 @@
# import matplotlib.patches as patches # import matplotlib.patches as patches
import matplotlib.patches as patches
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import matplotlib.patches as patches
from dataloader import extract_colors, preprocess_data from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel from model import ColorTransformerModel
@ -71,22 +71,35 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
# Each wedge in the circle # Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
width = 2 * np.pi / (N) # equal size for each wedge width = 2 * np.pi / (N) # equal size for each wedge
for i in range(N): for i in range(N):
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0) ax.bar(
theta[i],
1,
width=width,
edgecolor="none",
facecolor=colors[i],
bottom=0.0,
zorder=1,
alpha=1,
)
ax.set_xticks([]) ax.set_xticks([])
ax.set_yticks([]) ax.set_yticks([])
ax.axis("off") ax.axis("off")
ax.set_aspect("equal")
# Overlay white circle # Overlay white circle
radius = 1 / 3 radius = 1 / 3
circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2) circle = patches.Circle(
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle) ax.add_patch(circle)
fig.tight_layout(pad=0) fig.tight_layout(pad=0)
plt.savefig(f"{fname}.png", dpi=dpi)
plt.savefig(f"{fname}.png", dpi=dpi, transparent=False)
plt.close() plt.close()

BIN
hsv.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 MiB

After

Width:  |  Height:  |  Size: 1.1 MiB

2
hsv.py

@ -8,7 +8,7 @@ if __name__ == "__main__":
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy() xkcd_rgb = rgb_tensor.numpy()
xkcd_hsv = rgb_to_hsv(xkcd_rgb) xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=300, figsize=(6,6)) plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=150, figsize=(6, 6))
rgb = np.eye(3) rgb = np.eye(3)
print("Pure RGB in Hue-Space:") print("Pure RGB in Hue-Space:")
print(rgb_to_hsv(rgb)[:, 0]) print(rgb_to_hsv(rgb)[:, 0])

2
search.py

@ -2,7 +2,7 @@ import subprocess
import sys import sys
from random import sample from random import sample
import numpy as np import numpy as np # noqa: F401
from lightning_sdk import Machine, Studio # noqa: F401 from lightning_sdk import Machine, Studio # noqa: F401
NUM_JOBS = 100 NUM_JOBS = 100

Loading…
Cancel
Save