Browse Source

plotting updates

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
62e0764a70
  1. 26
      check.py
  2. BIN
      hsv.png
  3. 5
      hsv.py
  4. 4
      main.py
  5. 2
      out/index.html

26
check.py

@ -46,24 +46,27 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150):
M = ckpt
rgb_tensor, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
preds = M(rgb_tensor.to(M.device))
plot_preds(preds, fname=fname, dpi=dpi)
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, dpi=dpi)
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)):
rgb_tensor, _ = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor)
def plot_preds(
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds]
colors = rgb_values[sorted_inds, :]
if roll:
# find white in colors, put it first.
white = np.array([1, 1, 1])
white_idx = np.where((colors == white).all(axis=1))[0][0]
colors = np.roll(colors, -white_idx, axis=0)
white_idx = np.where((colors == white).all(axis=1))
if white_idx:
white_idx = white_idx[0][0]
colors = np.roll(colors, -white_idx, axis=0)
else:
print("no white, skipping")
# print(white_idx, colors[:2])
N = len(colors)
@ -76,10 +79,13 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3
for i in range(N):
ax.bar(
# 2 * np.pi * preds[i],
theta[i],
1,
height=1,
width=width,
edgecolor="none",
# facecolor=[rgb_values[i][1]]*3,
# facecolor=rgb_values[i],
facecolor=colors[i],
bottom=0.0,
zorder=1,

BIN
hsv.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

After

Width:  |  Height:  |  Size: 2.8 MiB

5
hsv.py

@ -7,8 +7,11 @@ from utils import extract_colors
if __name__ == "__main__":
rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy()
# xkcd_rgb = np.random.rand(1000, 3)
xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=150, figsize=(6, 6))
plot_preds(
xkcd_hsv[:, 0], xkcd_rgb, fname="hsv", roll=False, dpi=300, figsize=(6, 6)
)
rgb = np.eye(3)
print("Pure RGB in Hue-Space:")
print(rgb_to_hsv(rgb)[:, 0])

4
main.py

@ -72,13 +72,13 @@ if __name__ == "__main__":
save_img_callback = SaveImageCallback(
save_interval=0,
final_dir="out",
final_dir=None,
)
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_dataloader(
N=1e7,
N=1e5,
batch_size=args.bs,
shuffle=True,
num_workers=args.num_workers,

2
out/index.html

@ -74,7 +74,7 @@
for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100
let imageName;
if (i == 21) {
if (i == -21) {
imageName = 'hsv.png';
} else {
imageName = 'v' + i + '.png';

Loading…
Cancel
Save