diff --git a/baseline.py b/baseline.py new file mode 100644 index 0000000..216808b --- /dev/null +++ b/baseline.py @@ -0,0 +1,170 @@ +import argparse +from pathlib import Path + +import matplotlib.colors as mcolors +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +from hilbertcurve.hilbertcurve import HilbertCurve + +from check import plot_preds + +# Extract XKCD colors +colors = list(mcolors.XKCD_COLORS.keys()) +rgb_values = [mcolors.to_rgb(mcolors.XKCD_COLORS[color]) for color in colors] + +# Parse command-line arguments +parser = argparse.ArgumentParser() +parser.add_argument("-s", "--sort-by", type=str, default="hsv", help="kind of sorting") +parser.add_argument("--seed", type=int, default=21, help="seed for UMAP") +parser.add_argument("--dpi", type=int, default=300, help="dpi for saving") +parser.add_argument("--size", type=float, default=6.0, help="size of figure") +parser.add_argument( + "--fontsize", + type=float, + default=0, + help="fontsize of annotation (default: 0 = None)", +) +parser.add_argument( + "--radius", type=float, default=1 / 2, help="inner radius of circle" +) +args = parser.parse_args() +KIND = args.sort_by +SEED = args.seed +DPI = args.dpi +SIZE = args.size +FONTSIZE = args.fontsize +INNER_RADIUS = args.radius +DIR = "/teamspace/studios/colors/umap" + + +prefix = "" +if KIND == "umap": + prefix = f"{SEED:04d}_" +FDIR = f"{DIR}/{KIND}" +Path(FDIR).mkdir(exist_ok=True, parents=True) +fname = f"{FDIR}/{prefix}sorted_colors_circle.png" + + + +if KIND in ("lex", "alpha", "abc"): + preds = np.array(colors) + +elif KIND == "umap": + # from umap import UMAP + from cuml import UMAP + + # Use UMAP to create a 1D representation + reducer = UMAP( + n_components=1, + n_neighbors=250, + min_dist=1e-2, + metric="euclidean", + random_state=SEED, + negative_sample_rate=2, + ) + embedding = reducer.fit_transform(np.array(rgb_values)) + + # Sort colors by the 1D representation + preds = embedding[:, 0] + del reducer, embedding + +elif KIND in ("cielab", "lab", "ciede2000"): + from skimage.color import deltaE_ciede2000, rgb2lab + + # CIELAB + # Convert RGB values to CIELAB + lab_values = rgb2lab([rgb_values]) + + # Reference color for sorting (can be the first color or any other reference point) + reference_color = lab_values[0] + + # Compute CIEDE2000 distances of all colors to the reference color + distances = [deltaE_ciede2000(reference_color, color) for color in lab_values] + + # Sort colors by their CIEDE2000 distance to the reference color + # preds = distances).flatten() # awful + lab_values_flat = lab_values.reshape(-1, 3) + # Sort colors based on the L* value in the CIELAB space + # 0 corresponds to the L* channel + preds = lab_values_flat[:, 0] + +elif KIND == "hsv": + from matplotlib.colors import rgb_to_hsv + + # Convert RGB values to HSV + hsv_values = np.array([rgb_to_hsv(np.array(rgb)) for rgb in rgb_values]) + + # Sort colors based on the hue value + # 0 corresponds to the hue component + preds = hsv_values[:, 0] +else: + raise ValueError(f"Unknown kind: {KIND}") + +sorted_indices = np.argsort(preds) + +# Save the sorted indices to disk +# if (KIND == "umap") or (KIND != "umap"): +PDIR = f"scripts/{KIND}" +Path(PDIR).mkdir(parents=True, exist_ok=True) +file_path = f"{PDIR}/{SEED:06d}.npy" +np.save(file_path, preds.ravel()) +print(f"Predictions saved to {file_path}") + +# Sort colors by the 1D representation +sorted_colors = [colors[i] for i in sorted_indices] + +plot_preds( + preds, + np.array(rgb_values), + fname, + roll=False, + dpi=DPI, + inner_radius=INNER_RADIUS, + figsize=(SIZE, SIZE), + fsize=FONTSIZE, + label=f"{KIND.upper()}", +) +print(f"saved {fname}") + +HILBERT = False + +if HILBERT: + # Create Hilbert curve + # We'll set the order such that the number of positions is greater than or equal to the number of colors + hilbert_order = int(np.ceil(0.5 * np.log2(len(sorted_colors)))) + hilbert_curve = HilbertCurve(hilbert_order, 2) + + # Create an image for visualization + image_size = 2**hilbert_order + image = np.ones((image_size, image_size, 3)) + + for i, color in enumerate(sorted_colors): + # Convert linear index to Hilbert coordinates + coords = hilbert_curve.point_from_distance(i) + image[coords[1], coords[0]] = mcolors.to_rgb(color) + + # annotation in upper right + # Display the image + fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE)) + ax.imshow(image) + ax.annotate( + f"{KIND.upper()}", + (1.0, 1.0), + ha="right", + va="top", + size=FONTSIZE, + xycoords="axes fraction", + ) + ax.axis("off") + ax.set_aspect("equal") + fig.tight_layout() + fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png" + fig.savefig( + fname, + dpi=DPI, + transparent=True, + # bbox_inches="tight", + # pad_inches=0 + ) + print(f"Saved {fname}") diff --git a/check.py b/check.py index bd253f3..6160e39 100644 --- a/check.py +++ b/check.py @@ -1,6 +1,6 @@ # import matplotlib.patches as patches from pathlib import Path -from typing import Union +from typing import Union, Tuple import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -61,13 +61,15 @@ def create_circle( def plot_preds( - preds, + preds: torch.Tensor | np.ndarray, rgb_values, fname: str, roll: bool = False, - inner_radius=1 / 3, + inner_radius: float = 1 / 3, dpi: int = 300, - figsize=(6, 6), + figsize: Tuple[float] = (6, 6), + fsize: int = 0, + label: str = "", ): if isinstance(preds, torch.Tensor): preds = preds.detach().cpu().numpy() @@ -114,7 +116,7 @@ def plot_preds( ax.set_aspect("equal") ax.axis("off") radius = 1 - ax.set_ylim(-radius, radius) + ax.set_ylim(0, radius) # Overlay white circle circle = patches.Circle( @@ -122,11 +124,20 @@ def plot_preds( ) ax.add_patch(circle) + if fsize > 0.0: + center = (0, 0) + ax.annotate( + label, + center, + ha="center", + va="center", + size=fsize, + color="black", + ) + fig.tight_layout(pad=0) - plt.savefig( - f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight" - ) + plt.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight") plt.close() diff --git a/makefile b/makefile index 8acc8c7..6c9e806 100644 --- a/makefile +++ b/makefile @@ -56,10 +56,10 @@ umap: done sort_umap: - python scripts/sortcolor.py -s umap --dpi 300 --seed 21 + python baseline.py -s umap --dpi 300 --seed 21 parallel_umap: - parallel -j 4 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100) + parallel -j 4 python baseline.py -s umap --dpi 300 --seed ::: $$(seq 1 100) parallel_check: parallel -j 4 python check.py -v ::: $$(seq 0 99)