Browse Source

fine tune image

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
6899320927
  1. 12
      check.py
  2. BIN
      hsv.png
  3. 2
      hsv.py

12
check.py

@ -2,6 +2,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import matplotlib.patches as patches
from dataloader import extract_colors, preprocess_data from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel from model import ColorTransformerModel
@ -49,7 +50,7 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150):
plot_preds(preds, fname=fname, dpi=dpi) plot_preds(preds, fname=fname, dpi=dpi)
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)):
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
rgb_values = rgb_tensor.detach().numpy() rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor)
@ -67,7 +68,7 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150):
N = len(colors) N = len(colors)
# Create a plot with these hues in a circle # Create a plot with these hues in a circle
fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True)) fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
# Each wedge in the circle # Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2
@ -79,7 +80,12 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150):
ax.set_xticks([]) ax.set_xticks([])
ax.set_yticks([]) ax.set_yticks([])
ax.axis("off") ax.axis("off")
fig.tight_layout() # Overlay white circle
radius = 1 / 3
circle = patches.Circle((0, 0), radius, transform=ax.transData._b, color="white", zorder=2)
ax.add_patch(circle)
fig.tight_layout(pad=0)
plt.savefig(f"{fname}.png", dpi=dpi) plt.savefig(f"{fname}.png", dpi=dpi)
plt.close() plt.close()

BIN
hsv.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 329 KiB

After

Width:  |  Height:  |  Size: 2.7 MiB

2
hsv.py

@ -8,7 +8,7 @@ if __name__ == "__main__":
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy() xkcd_rgb = rgb_tensor.numpy()
xkcd_hsv = rgb_to_hsv(xkcd_rgb) xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True) plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=300, figsize=(6,6))
rgb = np.eye(3) rgb = np.eye(3)
print("Pure RGB in Hue-Space:") print("Pure RGB in Hue-Space:")
print(rgb_to_hsv(rgb)[:, 0]) print(rgb_to_hsv(rgb)[:, 0])

Loading…
Cancel
Save