Browse Source

split plotting for hsv. [todo]: finalize plotting

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
e66c938798
  1. 15
      check.py
  2. 16
      dataloader.py
  3. BIN
      hsv.png
  4. 14
      hsv.py
  5. 14
      utils.py

15
check.py

@ -43,11 +43,20 @@ def create_circle(ckpt: str, fname: str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
else:
M = ckpt
rgb_tensor, names = extract_colors()
rgb_tensor, _ = extract_colors()
preds = M(rgb_tensor)
plot_preds(preds, fname=fname)
def plot_preds(preds, fname: str):
rgb_tensor, _ = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor)
preds = M(rgb_tensor)
sorted_inds = np.argsort(preds.detach().numpy().ravel())
if isinstance(preds, torch.Tensor):
preds = preds.detach().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds]
# find white in colors, put it first.
white = np.array([1, 1, 1])

16
dataloader.py

@ -1,21 +1,7 @@
import matplotlib.colors as mcolors
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import preprocess_data
def extract_colors():
# Extracting the list of xkcd colors as RGB triples
xkcd_colors = mcolors.XKCD_COLORS
rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]
# Extracting the list of xkcd color names
xkcd_color_names = list(xkcd_colors.keys())
# Convert the list of RGB triples to a PyTorch tensor
rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
return rgb_tensor, xkcd_color_names
from utils import extract_colors, preprocess_data
def create_dataloader(N: int = 50, **kwargs):

BIN
hsv.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 329 KiB

14
hsv.py

@ -0,0 +1,14 @@
import numpy as np
from matplotlib.colors import rgb_to_hsv
from check import plot_preds
from utils import extract_colors
if __name__ == "__main__":
rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy()
xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv")
rgb = np.eye(3)
print("Pure RGB in Hue-Space:")
print(rgb_to_hsv(rgb)[:, 0])

14
utils.py

@ -1,3 +1,4 @@
import matplotlib.colors as mcolors
import torch
@ -20,6 +21,19 @@ def preprocess_data(data, skip=True):
return new_data
def extract_colors():
# Extracting the list of xkcd colors as RGB triples
xkcd_colors = mcolors.XKCD_COLORS
rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]
# Extracting the list of xkcd color names
xkcd_color_names = list(xkcd_colors.keys())
# Convert the list of RGB triples to a PyTorch tensor
rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
return rgb_tensor, xkcd_color_names
PURE_RGB = preprocess_data(
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
)

Loading…
Cancel
Save