diff --git a/check.py b/check.py index 2861a23..7a9bbca 100644 --- a/check.py +++ b/check.py @@ -43,11 +43,20 @@ def create_circle(ckpt: str, fname: str): M = ColorTransformerModel.load_from_checkpoint(ckpt) else: M = ckpt - rgb_tensor, names = extract_colors() + + rgb_tensor, _ = extract_colors() + preds = M(rgb_tensor) + plot_preds(preds, fname=fname) + + +def plot_preds(preds, fname: str): + rgb_tensor, _ = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor) - preds = M(rgb_tensor) - sorted_inds = np.argsort(preds.detach().numpy().ravel()) + + if isinstance(preds, torch.Tensor): + preds = preds.detach().numpy() + sorted_inds = np.argsort(preds.ravel()) colors = rgb_values[sorted_inds] # find white in colors, put it first. white = np.array([1, 1, 1]) diff --git a/dataloader.py b/dataloader.py index 9551ea1..66a4c0f 100644 --- a/dataloader.py +++ b/dataloader.py @@ -1,21 +1,7 @@ -import matplotlib.colors as mcolors import torch from torch.utils.data import DataLoader, TensorDataset -from utils import preprocess_data - - -def extract_colors(): - # Extracting the list of xkcd colors as RGB triples - xkcd_colors = mcolors.XKCD_COLORS - rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()] - - # Extracting the list of xkcd color names - xkcd_color_names = list(xkcd_colors.keys()) - - # Convert the list of RGB triples to a PyTorch tensor - rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32) - return rgb_tensor, xkcd_color_names +from utils import extract_colors, preprocess_data def create_dataloader(N: int = 50, **kwargs): diff --git a/hsv.png b/hsv.png new file mode 100644 index 0000000..839bcc1 Binary files /dev/null and b/hsv.png differ diff --git a/hsv.py b/hsv.py new file mode 100644 index 0000000..9dfccd2 --- /dev/null +++ b/hsv.py @@ -0,0 +1,14 @@ +import numpy as np +from matplotlib.colors import rgb_to_hsv + +from check import plot_preds +from utils import extract_colors + +if __name__ == "__main__": + rgb_tensor, _ = extract_colors() + xkcd_rgb = rgb_tensor.numpy() + xkcd_hsv = rgb_to_hsv(xkcd_rgb) + plot_preds(xkcd_hsv[:, 0], fname="hsv") + rgb = np.eye(3) + print("Pure RGB in Hue-Space:") + print(rgb_to_hsv(rgb)[:, 0]) diff --git a/utils.py b/utils.py index 8ce1708..4bea669 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,4 @@ +import matplotlib.colors as mcolors import torch @@ -20,6 +21,19 @@ def preprocess_data(data, skip=True): return new_data +def extract_colors(): + # Extracting the list of xkcd colors as RGB triples + xkcd_colors = mcolors.XKCD_COLORS + rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()] + + # Extracting the list of xkcd color names + xkcd_color_names = list(xkcd_colors.keys()) + + # Convert the list of RGB triples to a PyTorch tensor + rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32) + return rgb_tensor, xkcd_color_names + + PURE_RGB = preprocess_data( torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32) )