You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
4.7 KiB
171 lines
4.7 KiB
9 months ago
|
import argparse
|
||
|
from pathlib import Path
|
||
|
|
||
|
import matplotlib.colors as mcolors
|
||
|
import matplotlib.patches as patches
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
from hilbertcurve.hilbertcurve import HilbertCurve
|
||
|
|
||
|
from check import plot_preds
|
||
|
|
||
|
# Extract XKCD colors: `colors` holds the XKCD color names (dict order of
# mcolors.XKCD_COLORS) and `rgb_values` the matching RGB tuples produced by
# mcolors.to_rgb, index-aligned with `colors`.
colors = list(mcolors.XKCD_COLORS)
rgb_values = list(map(mcolors.to_rgb, mcolors.XKCD_COLORS.values()))
# Parse command-line arguments
# Every kind handled by the sort-key dispatch further down. Restricting
# --sort-by to these lets argparse reject a typo *before* any output
# directory is created for it.
SORT_KINDS = ("lex", "alpha", "abc", "umap", "cielab", "lab", "ciede2000", "hsv")

parser = argparse.ArgumentParser()
parser.add_argument(
    "-s",
    "--sort-by",
    type=str,
    default="hsv",
    choices=SORT_KINDS,
    help="kind of sorting",
)
parser.add_argument("--seed", type=int, default=21, help="seed for UMAP")
parser.add_argument("--dpi", type=int, default=300, help="dpi for saving")
parser.add_argument("--size", type=float, default=6.0, help="size of figure")
parser.add_argument(
    "--fontsize",
    type=float,
    default=0,
    help="fontsize of annotation (default: 0 = None)",
)
parser.add_argument(
    "--radius", type=float, default=1 / 2, help="inner radius of circle"
)
# The default keeps the original hard-coded location; override it when
# running outside that environment.
parser.add_argument(
    "--outdir",
    type=str,
    default="/teamspace/studios/colors/umap",
    help="root directory for the saved figures",
)
args = parser.parse_args()

KIND = args.sort_by
SEED = args.seed
DPI = args.dpi
SIZE = args.size
FONTSIZE = args.fontsize
INNER_RADIUS = args.radius
DIR = args.outdir

# Only the UMAP ordering depends on the seed, so only its figure names carry
# the (zero-padded) seed as a prefix.
prefix = f"{SEED:04d}_" if KIND == "umap" else ""

# Figures for each sort kind go into their own sub-directory of DIR.
FDIR = f"{DIR}/{KIND}"
Path(FDIR).mkdir(exist_ok=True, parents=True)
fname = f"{FDIR}/{prefix}sorted_colors_circle.png"
# Compute one scalar sort key per color. `preds` is index-aligned with
# `colors` / `rgb_values` and is argsorted afterwards to order the colors.
if KIND in ("lex", "alpha", "abc"):
    # Alphabetical order of the XKCD color names.
    preds = np.array(colors)

elif KIND == "umap":
    # from umap import UMAP
    from cuml import UMAP

    # Use UMAP to create a 1D representation of the RGB values.
    reducer = UMAP(
        n_components=1,
        n_neighbors=250,
        min_dist=1e-2,
        metric="euclidean",
        random_state=SEED,
        negative_sample_rate=2,
    )
    embedding = reducer.fit_transform(np.array(rgb_values))

    # The single embedding coordinate is the sort key.
    preds = embedding[:, 0]
    del reducer, embedding

elif KIND in ("cielab", "lab", "ciede2000"):
    from skimage.color import rgb2lab

    # Convert RGB values to CIELAB, presenting the N colors as a 1 x N image
    # (same (1, N, 3) input the list-wrapped form produced).
    lab_values = rgb2lab(np.asarray(rgb_values)[np.newaxis])

    # Sort by the L* (lightness) channel, column 0 after flattening to (N, 3).
    # NOTE: the "ciede2000" alias also sorts by lightness; no CIEDE2000
    # distance is computed.
    preds = lab_values.reshape(-1, 3)[:, 0]

elif KIND == "hsv":
    from matplotlib.colors import rgb_to_hsv

    # Convert all colors to HSV in one vectorized call on the (N, 3) array,
    # then sort by hue (column 0).
    hsv_values = rgb_to_hsv(np.asarray(rgb_values))
    preds = hsv_values[:, 0]

else:
    raise ValueError(f"Unknown kind: {KIND}")
# Rank the colors: entry k is the index (into `colors`) of the color with
# the k-th smallest sort key.
sorted_indices = preds.argsort()

# Persist the flattened sort keys as scripts/<kind>/<6-digit seed>.npy
# (relative to the current working directory).
PDIR = f"scripts/{KIND}"
file_path = f"{PDIR}/{SEED:06d}.npy"
Path(PDIR).mkdir(parents=True, exist_ok=True)
np.save(file_path, np.ravel(preds))
print(f"Predictions saved to {file_path}")
# Color names reordered by ascending sort key.
sorted_colors = list(map(colors.__getitem__, sorted_indices))

# Render the sorted colors with check.plot_preds and write the figure to
# `fname`.
plot_kwargs = dict(
    roll=False,
    dpi=DPI,
    inner_radius=INNER_RADIUS,
    figsize=(SIZE, SIZE),
    fsize=FONTSIZE,
    label=f"{KIND.upper()}",
)
plot_preds(preds, np.array(rgb_values), fname, **plot_kwargs)
print(f"saved {fname}")
# Optional Hilbert-curve rendering of the sorted colors (off by default).
HILBERT = False

if HILBERT:
    # Smallest curve order p whose 2**p x 2**p grid has a cell for every
    # color (4**p >= N). HilbertCurve requires p >= 1, hence the max().
    hilbert_order = max(1, int(np.ceil(0.5 * np.log2(len(sorted_colors)))))
    hilbert_curve = HilbertCurve(hilbert_order, 2)

    # White canvas; cells past the last color stay white.
    image_size = 2**hilbert_order
    image = np.ones((image_size, image_size, 3))

    for distance, color in enumerate(sorted_colors):
        # The k-th color in sort order goes to the k-th point along the curve.
        x, y = hilbert_curve.point_from_distance(distance)
        image[y, x] = mcolors.to_rgb(color)

    # Display the image
    fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
    ax.imshow(image)

    # --fontsize documents 0 as "None": skip the label rather than pass
    # size=0, which matplotlib clamps up to a 1 pt font.
    # NOTE(review): confirm this matches how check.plot_preds treats fsize=0.
    if FONTSIZE > 0:
        # Label with the sort kind in the upper-right corner.
        ax.annotate(
            f"{KIND.upper()}",
            (1.0, 1.0),
            ha="right",
            va="top",
            size=FONTSIZE,
            xycoords="axes fraction",
        )

    ax.axis("off")
    ax.set_aspect("equal")
    fig.tight_layout()

    fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png"
    fig.savefig(fname, dpi=DPI, transparent=True)
    # Free the figure's resources once it is written.
    plt.close(fig)
    print(f"Saved {fname}")