Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Pilosov, PhD
bd579913d9 update baseline. note nondeterminism 2024-03-03 00:07:28 +00:00
Michael Pilosov, PhD
8522aa0386 isinstance 2024-03-02 23:51:05 +00:00
Michael Pilosov, PhD
6321157df1 align plotting code with defaults 2024-03-02 23:49:41 +00:00
5 changed files with 202 additions and 13 deletions

170
baseline.py Normal file
View File

@ -0,0 +1,170 @@
import argparse
from pathlib import Path
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from hilbertcurve.hilbertcurve import HilbertCurve
from check import plot_preds
# Build two parallel lists from matplotlib's XKCD color table:
# `colors` holds the table's keys (color names) and `rgb_values` holds the
# corresponding table values converted with `mcolors.to_rgb`, in the same order.
colors = list(mcolors.XKCD_COLORS)
rgb_values = [mcolors.to_rgb(spec) for spec in mcolors.XKCD_COLORS.values()]
# Command-line interface. Every option has a default, so the script can be
# run without any arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-s",
    "--sort-by",
    type=str,
    default="hsv",
    help="kind of sorting",
)
parser.add_argument(
    "--seed",
    type=int,
    default=21,
    help="seed for UMAP",
)
parser.add_argument(
    "--dpi",
    type=int,
    default=300,
    help="dpi for saving",
)
parser.add_argument(
    "--size",
    type=float,
    default=6.0,
    help="size of figure",
)
parser.add_argument(
    "--fontsize",
    type=float,
    default=0,
    help="fontsize of annotation (default: 0 = None)",
)
parser.add_argument(
    "--radius",
    type=float,
    default=0.5,
    help="inner radius of circle",
)
args = parser.parse_args()

# Re-export the parsed options as the module-level constants used by the
# rest of the script.
KIND, SEED, DPI = args.sort_by, args.seed, args.dpi
SIZE, FONTSIZE, INNER_RADIUS = args.size, args.fontsize, args.radius
# Output locations for the rendered image: one sub-directory per sorting
# kind beneath DIR. Only UMAP runs get a zero-padded seed prefix on the
# file name; every other kind uses an empty prefix.
DIR = "/teamspace/studios/colors/umap"
prefix = f"{SEED:04d}_" if KIND == "umap" else ""
FDIR = f"{DIR}/{KIND}"
# Create the target directory (and any missing parents) up front.
Path(FDIR).mkdir(exist_ok=True, parents=True)
fname = f"{FDIR}/{prefix}sorted_colors_circle.png"
# Compute `preds`: one sortable value per XKCD color, chosen by KIND.
# Heavy / optional dependencies are imported inside the branch that needs
# them, so the other kinds work without those packages installed.
if KIND in ("lex", "alpha", "abc"):
    # Lexicographic ordering: the color names themselves are the sort keys.
    preds = np.array(colors)
elif KIND == "umap":
    # NOTE(review): a commented-out `from umap import UMAP` was left here in
    # the original, presumably as a CPU alternative to cuml -- confirm.
    from cuml import UMAP

    # Use UMAP to create a 1D representation of the RGB values.
    reducer = UMAP(
        n_components=1,
        n_neighbors=250,
        min_dist=1e-2,
        metric="euclidean",
        random_state=SEED,
        negative_sample_rate=2,
    )
    embedding = reducer.fit_transform(np.array(rgb_values))
    # The single embedding column is the sort key.
    preds = embedding[:, 0]
    del reducer, embedding
elif KIND in ("cielab", "lab", "ciede2000"):
    from skimage.color import rgb2lab

    # Convert RGB values to CIELAB. The extra list wrapper presents the
    # colors as a one-row "image"; reshape flattens it back to (N, 3).
    lab_values = rgb2lab([rgb_values])
    lab_values_flat = lab_values.reshape(-1, 3)
    # Sort on channel 0, the L* value. An earlier experiment sorted by
    # CIEDE2000 distance to a reference color instead; its result was never
    # used (the author's note called it "awful"), so that dead computation
    # has been removed.
    preds = lab_values_flat[:, 0]
elif KIND == "hsv":
    from matplotlib.colors import rgb_to_hsv

    # rgb_to_hsv accepts (..., 3) arrays, so convert every color in a single
    # vectorized call instead of one call per color.
    hsv_values = rgb_to_hsv(np.array(rgb_values))
    # Sort on channel 0, the hue component.
    preds = hsv_values[:, 0]
else:
    raise ValueError(f"Unknown kind: {KIND}")
# Ascending ordering of the colors implied by the predictions.
sorted_indices = np.argsort(preds)

# Persist the flattened, unsorted predictions for every kind, keyed by a
# zero-padded seed: scripts/<kind>/<seed>.npy
PDIR = f"scripts/{KIND}"
Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{SEED:06d}.npy"
np.save(file_path, preds.ravel())
print(f"Predictions saved to {file_path}")

# Color names rearranged into sorted order.
sorted_colors = [colors[idx] for idx in sorted_indices]

# Hand the raw predictions and RGB values to plot_preds (imported from
# `check`) together with the CLI-controlled figure settings; the image is
# expected to be written to `fname`.
plot_preds(
    preds,
    np.array(rgb_values),
    fname,
    roll=False,
    dpi=DPI,
    inner_radius=INNER_RADIUS,
    figsize=(SIZE, SIZE),
    fsize=FONTSIZE,
    label=KIND.upper(),
)
print(f"saved {fname}")
# Optional second visualization: place the sorted colors along a Hilbert
# curve on a square pixel grid. Disabled by default; set HILBERT to True to
# enable it.
HILBERT = False
if HILBERT:
    # Create Hilbert curve.
    # Pick the order so that the number of positions (4**order) is greater
    # than or equal to the number of colors.
    hilbert_order = int(np.ceil(0.5 * np.log2(len(sorted_colors))))
    hilbert_curve = HilbertCurve(hilbert_order, 2)

    # Create an image for visualization. It starts all ones (white), so grid
    # cells beyond the last color keep that value.
    image_size = 2**hilbert_order
    image = np.ones((image_size, image_size, 3))
    for i, color in enumerate(sorted_colors):
        # Convert linear index to Hilbert coordinates.
        coords = hilbert_curve.point_from_distance(i)
        image[coords[1], coords[0]] = mcolors.to_rgb(color)

    # Display the image.
    fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
    ax.imshow(image)
    # Annotation in the upper right. A FONTSIZE of 0 means "no annotation"
    # (see the --fontsize help text), so skip the call instead of asking
    # matplotlib to draw zero-sized text; this matches the guard used in
    # plot_preds.
    if FONTSIZE > 0:
        ax.annotate(
            f"{KIND.upper()}",
            (1.0, 1.0),
            ha="right",
            va="top",
            size=FONTSIZE,
            xycoords="axes fraction",
        )
    ax.axis("off")
    ax.set_aspect("equal")
    fig.tight_layout()

    # NOTE: this image is written directly under DIR (not the per-kind FDIR),
    # with the kind encoded in the file name instead.
    fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png"
    fig.savefig(
        fname,
        dpi=DPI,
        transparent=True,
        # bbox_inches="tight",
        # pad_inches=0
    )
    print(f"Saved {fname}")

View File

@ -1,6 +1,6 @@
# import matplotlib.patches as patches
from pathlib import Path
from typing import Union
from typing import Union, Tuple
import matplotlib.patches as patches
import matplotlib.pyplot as plt
@ -61,7 +61,15 @@ def create_circle(
def plot_preds(
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6)
preds: torch.Tensor | np.ndarray,
rgb_values,
fname: str,
roll: bool = False,
inner_radius: float = 1 / 3,
dpi: int = 300,
figsize: Tuple[float] = (6, 6),
fsize: int = 0,
label: str = "",
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
@ -108,20 +116,28 @@ def plot_preds(
ax.set_aspect("equal")
ax.axis("off")
radius = 1
ax.set_ylim(-radius, radius)
ax.set_ylim(0, radius)
# Overlay white circle
inner_radius = 1 / 3
circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)
if fsize > 0.0:
center = (0, 0)
ax.annotate(
label,
center,
ha="center",
va="center",
size=fsize,
color="black",
)
fig.tight_layout(pad=0)
plt.savefig(
f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
)
plt.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
plt.close()

View File

@ -56,10 +56,10 @@ umap:
done
sort_umap:
python scripts/sortcolor.py -s umap --dpi 300 --seed 21
python baseline.py -s umap --dpi 300 --seed 21
parallel_umap:
parallel -j 4 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
parallel -j 4 python baseline.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
parallel_check:
parallel -j 4 python check.py -v ::: $$(seq 0 99)

View File

@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401
# consistency of randomly sampled experiments.
seed(19920921)
NUM_JOBS = 50
NUM_JOBS = 33
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
@ -38,7 +38,7 @@ seeds = list(range(21, 1992))
optimizers = [
# "Adagrad",
"Adam",
# "SGD",
"SGD",
# "AdamW",
# "LBFGS",
# "RAdam",
@ -81,6 +81,7 @@ python newmain.py fit \
--model.bias true \
--model.loop true \
--model.transform tanh \
--model.dropout 0 \
--trainer.min_epochs 10 \
--trainer.max_epochs {me} \
--trainer.log_every_n_steps 3 \

View File

@ -215,10 +215,12 @@ def plot_preds(
rgb_values,
fname: str,
roll: bool = False,
dpi: int = 150,
inner_radius: float = 1 / 3,
figsize=(3, 3),
dpi: int = 300,
figsize=(6, 6),
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3]
if roll: