Compare commits
3 Commits
main
...
bd579913d9
Author | SHA1 | Date | |
---|---|---|---|
|
bd579913d9 | ||
|
8522aa0386 | ||
|
6321157df1 |
170
baseline.py
Normal file
170
baseline.py
Normal file
@ -0,0 +1,170 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from hilbertcurve.hilbertcurve import HilbertCurve
|
||||
|
||||
from check import plot_preds
|
||||
|
||||
# Extract XKCD colors
|
||||
colors = list(mcolors.XKCD_COLORS.keys())
|
||||
rgb_values = [mcolors.to_rgb(mcolors.XKCD_COLORS[color]) for color in colors]
|
||||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--sort-by", type=str, default="hsv", help="kind of sorting")
|
||||
parser.add_argument("--seed", type=int, default=21, help="seed for UMAP")
|
||||
parser.add_argument("--dpi", type=int, default=300, help="dpi for saving")
|
||||
parser.add_argument("--size", type=float, default=6.0, help="size of figure")
|
||||
parser.add_argument(
|
||||
"--fontsize",
|
||||
type=float,
|
||||
default=0,
|
||||
help="fontsize of annotation (default: 0 = None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--radius", type=float, default=1 / 2, help="inner radius of circle"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
KIND = args.sort_by
|
||||
SEED = args.seed
|
||||
DPI = args.dpi
|
||||
SIZE = args.size
|
||||
FONTSIZE = args.fontsize
|
||||
INNER_RADIUS = args.radius
|
||||
DIR = "/teamspace/studios/colors/umap"
|
||||
|
||||
|
||||
prefix = ""
|
||||
if KIND == "umap":
|
||||
prefix = f"{SEED:04d}_"
|
||||
FDIR = f"{DIR}/{KIND}"
|
||||
Path(FDIR).mkdir(exist_ok=True, parents=True)
|
||||
fname = f"{FDIR}/{prefix}sorted_colors_circle.png"
|
||||
|
||||
|
||||
|
||||
if KIND in ("lex", "alpha", "abc"):
|
||||
preds = np.array(colors)
|
||||
|
||||
elif KIND == "umap":
|
||||
# from umap import UMAP
|
||||
from cuml import UMAP
|
||||
|
||||
# Use UMAP to create a 1D representation
|
||||
reducer = UMAP(
|
||||
n_components=1,
|
||||
n_neighbors=250,
|
||||
min_dist=1e-2,
|
||||
metric="euclidean",
|
||||
random_state=SEED,
|
||||
negative_sample_rate=2,
|
||||
)
|
||||
embedding = reducer.fit_transform(np.array(rgb_values))
|
||||
|
||||
# Sort colors by the 1D representation
|
||||
preds = embedding[:, 0]
|
||||
del reducer, embedding
|
||||
|
||||
elif KIND in ("cielab", "lab", "ciede2000"):
|
||||
from skimage.color import deltaE_ciede2000, rgb2lab
|
||||
|
||||
# CIELAB
|
||||
# Convert RGB values to CIELAB
|
||||
lab_values = rgb2lab([rgb_values])
|
||||
|
||||
# Reference color for sorting (can be the first color or any other reference point)
|
||||
reference_color = lab_values[0]
|
||||
|
||||
# Compute CIEDE2000 distances of all colors to the reference color
|
||||
distances = [deltaE_ciede2000(reference_color, color) for color in lab_values]
|
||||
|
||||
# Sort colors by their CIEDE2000 distance to the reference color
|
||||
# preds = distances).flatten() # awful
|
||||
lab_values_flat = lab_values.reshape(-1, 3)
|
||||
# Sort colors based on the L* value in the CIELAB space
|
||||
# 0 corresponds to the L* channel
|
||||
preds = lab_values_flat[:, 0]
|
||||
|
||||
elif KIND == "hsv":
|
||||
from matplotlib.colors import rgb_to_hsv
|
||||
|
||||
# Convert RGB values to HSV
|
||||
hsv_values = np.array([rgb_to_hsv(np.array(rgb)) for rgb in rgb_values])
|
||||
|
||||
# Sort colors based on the hue value
|
||||
# 0 corresponds to the hue component
|
||||
preds = hsv_values[:, 0]
|
||||
else:
|
||||
raise ValueError(f"Unknown kind: {KIND}")
|
||||
|
||||
sorted_indices = np.argsort(preds)
|
||||
|
||||
# Save the sorted indices to disk
|
||||
# if (KIND == "umap") or (KIND != "umap"):
|
||||
PDIR = f"scripts/{KIND}"
|
||||
Path(PDIR).mkdir(parents=True, exist_ok=True)
|
||||
file_path = f"{PDIR}/{SEED:06d}.npy"
|
||||
np.save(file_path, preds.ravel())
|
||||
print(f"Predictions saved to {file_path}")
|
||||
|
||||
# Sort colors by the 1D representation
|
||||
sorted_colors = [colors[i] for i in sorted_indices]
|
||||
|
||||
plot_preds(
|
||||
preds,
|
||||
np.array(rgb_values),
|
||||
fname,
|
||||
roll=False,
|
||||
dpi=DPI,
|
||||
inner_radius=INNER_RADIUS,
|
||||
figsize=(SIZE, SIZE),
|
||||
fsize=FONTSIZE,
|
||||
label=f"{KIND.upper()}",
|
||||
)
|
||||
print(f"saved {fname}")
|
||||
|
||||
HILBERT = False
|
||||
|
||||
if HILBERT:
|
||||
# Create Hilbert curve
|
||||
# We'll set the order such that the number of positions is greater than or equal to the number of colors
|
||||
hilbert_order = int(np.ceil(0.5 * np.log2(len(sorted_colors))))
|
||||
hilbert_curve = HilbertCurve(hilbert_order, 2)
|
||||
|
||||
# Create an image for visualization
|
||||
image_size = 2**hilbert_order
|
||||
image = np.ones((image_size, image_size, 3))
|
||||
|
||||
for i, color in enumerate(sorted_colors):
|
||||
# Convert linear index to Hilbert coordinates
|
||||
coords = hilbert_curve.point_from_distance(i)
|
||||
image[coords[1], coords[0]] = mcolors.to_rgb(color)
|
||||
|
||||
# annotation in upper right
|
||||
# Display the image
|
||||
fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
|
||||
ax.imshow(image)
|
||||
ax.annotate(
|
||||
f"{KIND.upper()}",
|
||||
(1.0, 1.0),
|
||||
ha="right",
|
||||
va="top",
|
||||
size=FONTSIZE,
|
||||
xycoords="axes fraction",
|
||||
)
|
||||
ax.axis("off")
|
||||
ax.set_aspect("equal")
|
||||
fig.tight_layout()
|
||||
fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png"
|
||||
fig.savefig(
|
||||
fname,
|
||||
dpi=DPI,
|
||||
transparent=True,
|
||||
# bbox_inches="tight",
|
||||
# pad_inches=0
|
||||
)
|
||||
print(f"Saved {fname}")
|
30
check.py
30
check.py
@ -1,6 +1,6 @@
|
||||
# import matplotlib.patches as patches
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, Tuple
|
||||
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
@ -61,7 +61,15 @@ def create_circle(
|
||||
|
||||
|
||||
def plot_preds(
|
||||
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6)
|
||||
preds: torch.Tensor | np.ndarray,
|
||||
rgb_values,
|
||||
fname: str,
|
||||
roll: bool = False,
|
||||
inner_radius: float = 1 / 3,
|
||||
dpi: int = 300,
|
||||
figsize: Tuple[float] = (6, 6),
|
||||
fsize: int = 0,
|
||||
label: str = "",
|
||||
):
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().cpu().numpy()
|
||||
@ -108,20 +116,28 @@ def plot_preds(
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
radius = 1
|
||||
ax.set_ylim(-radius, radius)
|
||||
ax.set_ylim(0, radius)
|
||||
|
||||
# Overlay white circle
|
||||
inner_radius = 1 / 3
|
||||
circle = patches.Circle(
|
||||
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
|
||||
if fsize > 0.0:
|
||||
center = (0, 0)
|
||||
ax.annotate(
|
||||
label,
|
||||
center,
|
||||
ha="center",
|
||||
va="center",
|
||||
size=fsize,
|
||||
color="black",
|
||||
)
|
||||
|
||||
fig.tight_layout(pad=0)
|
||||
|
||||
plt.savefig(
|
||||
f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
|
||||
)
|
||||
plt.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
|
||||
|
4
makefile
4
makefile
@ -56,10 +56,10 @@ umap:
|
||||
done
|
||||
|
||||
sort_umap:
|
||||
python scripts/sortcolor.py -s umap --dpi 300 --seed 21
|
||||
python baseline.py -s umap --dpi 300 --seed 21
|
||||
|
||||
parallel_umap:
|
||||
parallel -j 4 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
|
||||
parallel -j 4 python baseline.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
|
||||
|
||||
parallel_check:
|
||||
parallel -j 4 python check.py -v ::: $$(seq 0 99)
|
||||
|
@ -8,7 +8,7 @@ from lightning_sdk import Machine, Studio # noqa: F401
|
||||
# consistency of randomly sampled experiments.
|
||||
seed(19920921)
|
||||
|
||||
NUM_JOBS = 50
|
||||
NUM_JOBS = 33
|
||||
|
||||
# reference to the current studio
|
||||
# if you run outside of Lightning, you can pass the Studio name
|
||||
@ -38,7 +38,7 @@ seeds = list(range(21, 1992))
|
||||
optimizers = [
|
||||
# "Adagrad",
|
||||
"Adam",
|
||||
# "SGD",
|
||||
"SGD",
|
||||
# "AdamW",
|
||||
# "LBFGS",
|
||||
# "RAdam",
|
||||
@ -81,6 +81,7 @@ python newmain.py fit \
|
||||
--model.bias true \
|
||||
--model.loop true \
|
||||
--model.transform tanh \
|
||||
--model.dropout 0 \
|
||||
--trainer.min_epochs 10 \
|
||||
--trainer.max_epochs {me} \
|
||||
--trainer.log_every_n_steps 3 \
|
||||
|
@ -215,10 +215,12 @@ def plot_preds(
|
||||
rgb_values,
|
||||
fname: str,
|
||||
roll: bool = False,
|
||||
dpi: int = 150,
|
||||
inner_radius: float = 1 / 3,
|
||||
figsize=(3, 3),
|
||||
dpi: int = 300,
|
||||
figsize=(6, 6),
|
||||
):
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().cpu().numpy()
|
||||
sorted_inds = np.argsort(preds.ravel())
|
||||
colors = rgb_values[sorted_inds, :3]
|
||||
if roll:
|
||||
|
Loading…
Reference in New Issue
Block a user