Compare commits

...

81 Commits

Author SHA1 Message Date
Michael Pilosov, PhD
f0a4c940af bg filling 2024-02-14 16:05:39 +00:00
Michael Pilosov, PhD
04ff370bbb wip 2024-02-14 04:55:42 +00:00
Michael Pilosov, PhD
90f08a5edb paths 2024-02-14 04:31:27 +00:00
Michael Pilosov, PhD
831f17a9b4 tweaks + path changes 2024-02-13 06:55:21 +00:00
Michael Pilosov, PhD
698f6f1b51 seeds 2024-02-13 06:46:52 +00:00
Michael Pilosov, PhD
13dd754ba6 nice set of params 2024-02-13 06:46:18 +00:00
Michael Pilosov, PhD
a97055f282 GNU parallel 2024-02-13 05:35:54 +00:00
Michael Pilosov, PhD
d1b963e05d minimal requirements. 2024-02-13 05:35:44 +00:00
Michael Pilosov, PhD
f3bc1e90d3 ignore npy 2024-02-13 05:28:26 +00:00
Michael Pilosov, PhD
b95cb8df3d cuda libs 2024-02-13 05:28:18 +00:00
Michael Pilosov, PhD
527cc285f3 drop in cuml 2024-02-13 05:21:39 +00:00
Michael Pilosov, PhD
95a6adac91 lint 2024-02-13 03:50:44 +00:00
Michael Pilosov, PhD
447b773d03 no annotate, update target 2024-02-13 03:42:37 +00:00
Michael Pilosov, PhD
00cba285c8 minor tweaks, check annotations 2024-02-13 03:34:31 +00:00
Michael Pilosov, PhD
9953327d30 copy over color files, use new plotting 2024-02-13 03:26:37 +00:00
Michael Pilosov, PhD
0ae8414481 syntax, no diff 2024-01-28 09:54:38 +00:00
Michael Pilosov, PhD
2af491c324 tweak metric 2024-01-28 09:43:36 +00:00
Michael Pilosov, PhD
5687f30818 back to "correct" metric 2024-01-28 08:32:05 +00:00
Michael Pilosov, PhD
5def982f12 recreate iris looking result. weird though. 2024-01-28 08:09:13 +00:00
Michael Pilosov, PhD
a9e772f34e restrict experiment, add loop option + dropout 2024-01-28 07:31:49 +00:00
Michael Pilosov
697030b7df partial unsupervision 2024-01-27 23:34:25 -07:00
Michael Pilosov
f450611ce3 pin requirements 2024-01-27 23:33:05 -07:00
Michael Pilosov, PhD
0c4990e98c remove depth = 8 and switch to relu 2024-01-28 05:42:04 +00:00
Michael Pilosov, PhD
05dd4e29ce sequential instead of loopback 2024-01-28 05:26:12 +00:00
Michael Pilosov, PhD
d318480b7c circle norm bad for supervise? 2024-01-28 02:58:51 +00:00
Michael Pilosov, PhD
1c116f3f12 benchmark supervised again 2024-01-28 02:54:30 +00:00
Michael Pilosov, PhD
c7ffd09fb4 lr scheduler 2024-01-28 02:44:36 +00:00
Michael Pilosov, PhD
4342a54cc8 anchors and more train data, exp w batch 2024-01-28 02:43:35 +00:00
Michael Pilosov, PhD
248d1a72f9 re-anchor, pretty meh results in-batch 2024-01-28 01:52:21 +00:00
Michael Pilosov, PhD
9e4861a272 try unsupervised again, but with 10k random samples 2024-01-28 01:33:10 +00:00
Michael Pilosov, PhD
e5b6f287a3 xkcd colors may be too few to learn from. need 10x 2024-01-28 01:32:56 +00:00
Michael Pilosov, PhD
865e7f5104 supervised questionable 2024-01-28 01:25:10 +00:00
Michael Pilosov, PhD
b6d9f94d8e tracked down the losses bug 2024-01-28 00:51:42 +00:00
Michael Pilosov, PhD
953488be4c remove modulo from 3-space 2024-01-27 23:52:04 +00:00
Michael Pilosov, PhD
0e561aae4c big batch 2024-01-27 22:38:01 +00:00
Michael Pilosov, PhD
30470f13bc unsupervised results looking stunning 2024-01-27 22:26:41 +00:00
Michael Pilosov, PhD
70f56ff9f0 cleanup temp files from interrupts 2024-01-27 22:22:49 +00:00
Michael Pilosov, PhD
3adcc9779a type hint, rename vars 2024-01-27 22:17:23 +00:00
Michael Pilosov, PhD
a49f166252 seems to learn (supervised) consistently now 2024-01-27 21:55:45 +00:00
Michael Pilosov, PhD
a44580a15b depth of 1 led to consistently bad learning 2024-01-27 21:50:33 +00:00
Michael Pilosov, PhD
eee8a8b0ba update page 2024-01-27 21:10:42 +00:00
Michael Pilosov, PhD
6260e7fdd3 restrict optimizers 2024-01-27 21:08:44 +00:00
Michael Pilosov, PhD
467b3f7e57 entirely supervised 2024-01-27 21:03:32 +00:00
Michael Pilosov, PhD
1da8d3194a first one w weighted alpha 2024-01-27 20:59:16 +00:00
Michael Pilosov, PhD
b5d9e725b3 allow for mix of supervised and not with alpha 2024-01-27 19:48:04 +00:00
Michael Pilosov, PhD
721993d9e5 switch to unsupervised 2024-01-27 19:43:20 +00:00
Michael Pilosov, PhD
b9d334e49a weird setup, but sometimes possible to learn labels 2024-01-27 19:38:40 +00:00
Michael Pilosov, PhD
07b4e548e2 lets make sure this can learn via supervision first 2024-01-27 18:58:06 +00:00
Michael Pilosov, PhD
948bc31861 cull experimental range 2024-01-27 17:58:36 +00:00
Michael Pilosov, PhD
fff2b88fa1 job updates 2024-01-27 10:14:52 +00:00
Michael Pilosov, PhD
49e6260346 experiment loggin setup 2024-01-27 09:50:20 +00:00
Michael Pilosov, PhD
6d40d39097 save config to out 2024-01-27 09:39:25 +00:00
Michael Pilosov, PhD
3b700aee70 optimizer search 2024-01-27 09:34:10 +00:00
Michael Pilosov, PhD
1e818aa977 remove s loss entirely 2024-01-27 09:29:22 +00:00
Michael Pilosov, PhD
1ea29ba11e some missing params 2024-01-27 09:26:19 +00:00
Michael Pilosov, PhD
e1ac3211b9 use lightning CLI everywhere 2024-01-27 09:14:14 +00:00
Michael Pilosov, PhD
012c7b7c68 add validation step 2024-01-27 07:48:04 +00:00
Michael Pilosov, PhD
88c8cde9f6 total overhaul of model 2024-01-27 07:27:57 +00:00
Michael Pilosov, PhD
b9b6ee7727 disable min/max hints 2024-01-27 05:41:41 +00:00
Michael Pilosov, PhD
e61543299d cleanup 2024-01-27 05:41:31 +00:00
Michael Pilosov, PhD
2ec2d1f368 requirements 2024-01-27 05:20:29 +00:00
Michael Pilosov, PhD
50628e594a fix image gen 2024-01-26 03:02:33 +00:00
Michael Pilosov, PhD
a02a662b6f image previews 2024-01-26 00:24:21 +00:00
Michael Pilosov, PhD
7ce24b0cd3 plotting args 2024-01-25 06:18:10 +00:00
Michael Pilosov
947a7d4a56 bugfix 2024-01-16 05:47:46 +00:00
Michael Pilosov
0366b5d0f1 try anchoring secondary colors 2024-01-16 05:19:54 +00:00
Michael Pilosov
70ecd7d7db another attempt 2024-01-16 04:38:13 +00:00
Michael Pilosov
62e0764a70 plotting updates 2024-01-16 03:05:26 +00:00
Michael Pilosov
72a1ad2971 plotting improvements 2024-01-15 21:39:44 +00:00
Michael Pilosov
6899320927 fine tune image 2024-01-15 21:05:31 +00:00
Michael Pilosov
999d73f7ab prep for search 2024-01-15 20:11:54 +00:00
Michael Pilosov
e709d8f34f lr 2024-01-15 19:45:40 +00:00
Michael Pilosov
686e096b97 saturate T4 GPU 2024-01-15 19:26:59 +00:00
Michael Pilosov
1f7d4c1890 use millions of colors 2024-01-15 19:18:28 +00:00
Michael Pilosov
5ed305fe34 dpi argument 2024-01-15 19:02:26 +00:00
Michael Pilosov
7461406cf2 gpu updates 2024-01-15 07:00:50 +00:00
Michael Pilosov
0934ca0aed parameterize network width 2024-01-15 06:35:48 +00:00
Michael Pilosov
a8b723f021 prep search for arch 2024-01-15 06:31:17 +00:00
Michael Pilosov
1f96c65b21 hsv image inclusion 2024-01-15 05:51:52 +00:00
Michael Pilosov
b5bb3fe3df fixed alpha 2024-01-15 05:32:17 +00:00
Michael Pilosov
a8d62bfca0 new separation loss: absolute vals 2024-01-15 05:30:45 +00:00
30 changed files with 1952 additions and 225 deletions

3
.gitignore vendored

@@ -5,3 +5,6 @@ out/
.sw[opqr]
*.tar.gz
.pat
out*
.lr*
*.npy


@@ -1,11 +1,11 @@
from pathlib import Path
import pytorch_lightning as pl
from lightning import Callback
from check import create_circle
class SaveImageCallback(pl.Callback):
class SaveImageCallback(Callback):
def __init__(self, save_interval=1, final_dir: str = None):
self.save_interval = save_interval
self.final_dir = final_dir


@@ -1,4 +1,7 @@
# import matplotlib.patches as patches
from typing import Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -9,7 +12,7 @@ from model import ColorTransformerModel
# import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True):
def make_image(ckpt: str, fname: str, color=True, **kwargs):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
# preds = M(rgb_tensor)
@@ -35,52 +38,87 @@ def make_image(ckpt: str, fname: str, color=True):
ax.axis("off")
# ax.axis("square")
plt.savefig(f"{fname}.png", dpi=300)
plt.savefig(f"{fname}.png", **kwargs)
def create_circle(ckpt: str, fname: str):
def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
else:
M = ckpt
rgb_tensor, _ = extract_colors()
preds = M(rgb_tensor)
plot_preds(preds, fname=fname)
xkcd_colors, _ = extract_colors()
xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors)
rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds(preds, fname: str, roll: bool = False):
rgb_tensor, _ = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor)
def plot_preds(
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().numpy()
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds]
colors = rgb_values[sorted_inds, :3]
if roll:
# find white in colors, put it first.
white = np.array([1, 1, 1])
white_idx = np.where((colors == white).all(axis=1))[0][0]
colors = np.roll(colors, -white_idx, axis=0)
white_idx = np.where((colors == white).all(axis=1))
if white_idx:
white_idx = white_idx[0][0]
colors = np.roll(colors, -white_idx, axis=0)
else:
print("no white, skipping")
# print(white_idx, colors[:2])
N = len(colors)
# Create a plot with these hues in a circle
fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True))
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
# Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
width = 2 * np.pi / (N) # equal size for each wedge
for i in range(N):
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0)
ax.bar(
# 2 * np.pi * preds[i],
theta[i],
height=1,
width=width,
edgecolor=colors[i],
linewidth=0.25,
# facecolor=[rgb_values[i][1]]*3,
# facecolor=rgb_values[i],
facecolor=colors[i],
bottom=0.0,
zorder=1,
alpha=1,
align="edge",
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect("equal")
ax.axis("off")
fig.tight_layout()
plt.savefig(f"{fname}.png", dpi=150)
radius = 1
ax.set_ylim(-radius, radius)
ax.set_xlim(-radius, radius)
# Overlay white circle
inner_radius = 1 / 3
circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)
fig.tight_layout(pad=0)
plt.savefig(
f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
)
plt.close()
@ -91,18 +129,22 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
# make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1])
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
parser.add_argument(
"--dpi", type=int, default=150, help="Resolution for saved image."
)
parser.add_argument("--figsize", type=int, default=3, help="Figure size")
args = parser.parse_args()
versions = args.version
for v in versions:
name = f"out/v{v}"
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt_path = f"/teamspace/studios/colors-refactor-secondary/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt = glob.glob(ckpt_path)
if len(ckpt) > 0:
ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name)
create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2)
else:
print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False)
# make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)

Binary file not shown (before: 28 KiB).

235
color_poster.py Normal file

@@ -0,0 +1,235 @@
from typing import List
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
# set font by default to courier new
plt.rcParams["font.family"] = "PT Mono"
# # sort by the proximity to the colors in viridis
# # Importing necessary functions
# # Calculate the proximity of each color in XKCD_COLORS to the viridis colormap
# def calculate_proximity_to_viridis(color):
# rgb_triple = mcolors.to_rgb(mcolors.XKCD_COLORS[color])
# distances = [(sum((a - b)**2 for a, b in zip(rgb_triple, viridis(i))), i) for i in range(256)]
# _, closest_viridis_value = min(distances, key=lambda x: x[0]) # Find the viridis color with the minimum distance
# return closest_viridis_value / 255 # Normalize to range (0, 1)
# # Calculate the proximity values for each color
# proximity_values = {color: calculate_proximity_to_viridis(color) for color in colors}
# # Sort the colors based on their proximity values
# sorted_colors = sorted(proximity_values.keys(), key=lambda x: proximity_values[x])
def create_color_calibration_image(
colors, ppi: List[int] = [100], index=0, solid_capstyle="butt", antialiased=True
):
first_color = colors[0]
last_color = colors[-1]
print(f"Processing color range: {first_color} to {last_color}")
# Conversion factor: 1 inch = dpi pixels
vert_space = 1 / 2 # inches
fontsize = 12 # points
# Figure settings
fig_width = 4.0 # inches
fig_height = len(colors) * vert_space
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
# plot a vertical black rectangle between x=3 and x=4
ax.axvspan(0.25, 0.75, facecolor="black")
# ax.axvspan(-0.325-0.175, -0.125, facecolor="black")
# Loop through each color
if index % 2 == 0:
skip = -1
else:
skip = 1
for idx, color in enumerate(colors[::skip]):
# print(color)
y_position = 0.5 - 0.125 + idx * vert_space # Offset each color by 0.3 inches
# Draw color name
rgb_triple = mcolors.to_rgb(mcolors.XKCD_COLORS[color])
# round to 4 decimal places
rgb_triple = tuple(round(x, 4) for x in rgb_triple)
# format as string with fixed decimal places
rgb_triple = ", ".join([f"{x:1.4f}" for x in rgb_triple])
hex_code = mcolors.to_hex(mcolors.XKCD_COLORS[color])
ax.text(
1.0,
y_position,
color.replace("xkcd:", ""),
va="center",
fontsize=fontsize,
# bbox=dict(facecolor='gray', alpha=0.21),
)
ax.text(
1.25,
y_position - 1.5 * fontsize / 72,
f"{hex_code}\n({rgb_triple})",
va="center",
fontsize=6,
# bbox=dict(facecolor='gray', alpha=0.33),
)
# ax.text(
# 1.125,
# y_position - 2 * fontsize / 72,
# f"{hex_code}",
# va="center",
# fontsize=fontsize * 0.6,
# )
# Draw color square
rect_height = 0.25
rect_width = 0.25
square_x_start = 0.75 - rect_width / 2 # Offset from the left
square_y_start = y_position - rect_height # + 0.25
ax.add_patch(
plt.Rectangle(
(square_x_start, square_y_start), rect_width, rect_height, fc=color
)
)
# Draw lines with varying stroke sizes
line_x_start = 0 # Offset from the left
line_length = 0.5
line_y_start = y_position - rect_height - 0.075
line_widths = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2.0, 3.0][::-1]
for width in line_widths:
ax.plot(
[line_x_start, line_x_start + line_length],
[line_y_start, line_y_start],
linewidth=width,
color=color,
antialiased=antialiased,
solid_capstyle=solid_capstyle,
)
line_y_start += 0.05
# # now repeat but vertical of height 1
# line_x_start = 3.125 - 0.05 # Offset from the left
# line_y_start = y_position + dpi / 960
# for width in line_widths:
# ax.plot(
# [line_x_start, line_x_start],
# [line_y_start, line_y_start - 0.5],
# linewidth=width,
# color=color,
# antialiased=True,
# )
# ax.plot(
# [0.5 + line_x_start, 0.5 + line_x_start],
# [line_y_start, line_y_start - 0.5],
# linewidth=width,
# color=color,
# antialiased=True,
# )
# line_x_start += 0.05
# Save the image
# Remove axes
ax.axis("off")
# ax.set_aspect("equal")
# plt.tight_layout(pad=0)
ax.set_ylim([0, fig_height])
# ax.set_xlim([0, fig_width])
ax.set_xlim([0, fig_width])
# pad = 0.108
pad = 0
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
output_paths = []
for _dpi in ppi:
_out_path = f"/tmp/color_calibration-{_dpi}_{index:02d}.png"
plt.savefig(_out_path, pad_inches=pad, dpi=_dpi)
output_paths.append(_out_path)
# plt.show()
plt.close()
return output_paths
if __name__ == "__main__":
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument(
"--rows", type=int, default=73, help="Number of entries per column"
)
parser.add_argument(
"--dir", type=str, default="/Volumes/TMP/tests", help="Directory to save images"
)
parser.add_argument(
"-k",
"--kind",
type=str,
nargs="+",
default=["hsv", "lex", "lab", "umap"],
help="Kinds of sorting",
)
parser.add_argument(
"--ppi", action="append", type=int, default=[300], help="Pixels per inch"
)
parser.add_argument("--aliased", action="store_true", help="Disable antialiasing")
parser.add_argument(
"--capstyle", type=str, default="butt", help="Capstyle of lines"
)
args = parser.parse_args()
COLUMN_LENGTH = args.rows
KINDS = args.kind
PPIS = args.ppi
DIR = args.dir
ANTIALIASED = not args.aliased
CAPSTYLE = args.capstyle
# COLUMN_LENGTH = 73 # results in 13 unfiltered columns (perfect)
# COLUMN_LENGTH = (
# 106 # results in 9 unfiltered columns (last one short), square-ish image
# )
OMITTED_COLORS = [
# "black",
# "white",
# "poop",
# "poo brown",
# "shit",
# "shit brown",
]
# OMITTED_COLORS = list(map(lambda s: f"xkcd:{s}", OMITTED_COLORS))
# KIND = "hsv" # choose from umap, hsv
for KIND in KINDS:
colors = list(mcolors.XKCD_COLORS.keys())
sorted_indices = np.load(f"scripts/{KIND}_sorted_indices.npy")
sorted_colors = [colors[idx] for idx in sorted_indices]
colors = sorted_colors
colors = [c for c in colors if c not in OMITTED_COLORS]
print(f"Total number of colors: {len(colors)}")
chunks = [
colors[i : i + COLUMN_LENGTH] for i in range(0, len(colors), COLUMN_LENGTH)
]
for idx, color_part in enumerate(chunks):
image_path = create_color_calibration_image(
colors=color_part,
ppi=PPIS,
index=idx,
antialiased=ANTIALIASED,
solid_capstyle=CAPSTYLE,
)
os.system(f"identify {image_path[0]}")
for PPI in PPIS:
# use imagemagick to stitch together the images horizontally
os.system(
f"convert +append /tmp/color_calibration-{PPI}_*.png /tmp/color_calibration-{PPI}.png"
)
os.system(f"rm /tmp/color_calibration-{PPI}_*")
print(f"Final image saved to /tmp/color_calibration-{PPI}.png")
os.system(
f"mkdir -p {DIR} && cp /tmp/color_calibration-{PPI}.png {DIR}/xkcd_{COLUMN_LENGTH}_{KIND}_{PPI}.png"
)
print(f"Copied to {DIR}/xkcd_{COLUMN_LENGTH}_{KIND}_{PPI}.png")


@@ -4,39 +4,42 @@ from torch.utils.data import DataLoader, TensorDataset
from utils import extract_colors, preprocess_data
def create_dataloader(N: int = 50, **kwargs):
rgb_tensor, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
def create_random_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset and data loader
dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
train_dataloader = DataLoader(dataset, **kwargs)
return train_dataloader
def create_gray_supplement(N: int = 50):
def create_gray_supplement(N: int = 50, skip: bool = True):
linear_space = torch.linspace(0, 1, N)
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
gray_tensor = preprocess_data(gray_tensor)
gray_tensor = preprocess_data(gray_tensor, skip=skip)
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
def create_named_dataloader(N: int = 0, **kwargs):
def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset with RGB values and their corresponding color names
dataset_with_names = [
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
for i in range(len(rgb_tensor))
]
if N > 0:
dataset_with_names += create_gray_supplement(N)
dataset_with_names += create_gray_supplement(N, skip=skip)
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
return train_dataloader_with_names
if __name__ == "__main__":
batch_size = 4
train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
train_dataloader = create_random_dataloader(
N=1e6, batch_size=batch_size, shuffle=True
)
print(len(train_dataloader.dataset))
train_dataloader_with_names = create_named_dataloader(
batch_size=batch_size, shuffle=True
)

103
datamodule.py Normal file

@@ -0,0 +1,103 @@
import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader
from utils import extract_colors, preprocess_data
class ColorDataModule(L.LightningDataModule):
def __init__(
self,
val_size: int = 10_000,
train_size=0,
batch_size: int = 32,
num_workers: int = 3,
):
super().__init__()
self.val_size = val_size
self.train_size = train_size
self.batch_size = batch_size
self.num_workers = num_workers
def prepare_data(self):
# no state. called from main process.
pass
@classmethod
def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)
@classmethod
def get_random_colors(cls, size: int):
train_rgb = torch.rand((int(size), 3), dtype=torch.float32)
train_rgb = preprocess_data(train_rgb, skip=True)
return [(c, cls.get_hue(c)) for c in train_rgb]
@classmethod
def get_xkcd_colors(cls):
rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor, skip=True)
# return [
# (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
# for i in range(len(rgb_tensor))
# ]
return [(c, cls.get_hue(c)) for c in rgb_tensor]
def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
if stage == "fit":
self.color_val = self.get_random_colors(self.val_size)
if self.train_size > 0:
self.color_train = self.get_random_colors(self.train_size)
else:
self.color_train = self.get_xkcd_colors()
# Assign test dataset for use in dataloader(s)
if stage == "test":
self.color_test = self.get_random_colors(self.val_size)
if stage == "predict": # for visualizing
self.color_predict = self.get_xkcd_colors()
def train_dataloader(self):
return DataLoader(
self.color_train,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self):
return DataLoader(
self.color_val,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self):
return DataLoader(
self.color_test,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def predict_dataloader(self):
return DataLoader(
self.color_predict,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def teardown(self, stage: str):
# Used to clean-up when the run is finished
pass
if __name__ == "__main__":
cdm = ColorDataModule()
cdm.setup("train")
print(cdm)
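A minimal sketch of wiring this datamodule into a Lightning run (assumes model.py from this diff is importable next to datamodule.py; loop=True sidesteps the Dropout branch, matching the flags newsearch.py passes):

import lightning as L

from datamodule import ColorDataModule
from model import ColorTransformerModel

# train_size=0 falls back to the XKCD color list, per setup("fit").
model = ColorTransformerModel(width=128, depth=1, loop=True, alpha=0.0)
data = ColorDataModule(val_size=1_000, train_size=0, batch_size=64)
trainer = L.Trainer(fast_dev_run=True)  # single-batch smoke test
trainer.fit(model, datamodule=data)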


@@ -1,25 +0,0 @@
,batch_size,alpha,learning_rate
0,32.0,0.3,0.0001
1,32.0,0.3,0.01
2,32.0,0.9,1e-06
3,32.0,0.7,0.001
4,64.0,0.5,0.001
5,64.0,0.1,1e-06
6,32.0,0.1,0.001
7,128.0,0.5,1e-06
8,128.0,0.7,0.001
9,128.0,0.9,1e-05
10,128.0,0.1,1e-06
11,128.0,0.3,1e-06
12,64.0,0.3,0.01
13,64.0,0.1,1e-06
14,128.0,0.5,0.001
15,32.0,0.3,1e-05
16,32.0,0.7,1e-06
17,32.0,0.3,1e-06
18,64.0,0.3,0.0001
19,64.0,0.3,1e-06
20,128.0,0.5,1e-05
21,32.0,0.1,0.01
22,64.0,0.1,1e-05
23,64.0,0.3,0.001

BIN
hsv.png
Binary file not shown (before: 328 KiB, after: 2.7 MiB).

8
hsv.py

@@ -7,8 +7,12 @@ from utils import extract_colors
if __name__ == "__main__":
rgb_tensor, _ = extract_colors()
xkcd_rgb = rgb_tensor.numpy()
# xkcd_rgb = np.random.rand(1000, 3)
xkcd_hsv = rgb_to_hsv(xkcd_rgb)
plot_preds(xkcd_hsv[:, 0], fname="hsv")
rgb = np.eye(3)
plot_preds(
xkcd_hsv[:, 0], xkcd_rgb, fname="hsv", roll=True, dpi=300, figsize=(6, 6)
)
rgb = np.vstack([np.eye(3), np.eye(3) + np.eye(3)[:, [1, 2, 0]]])
print("Pure RGB in Hue-Space:")
print(rgb)
print(rgb_to_hsv(rgb)[:, 0])
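As a sanity check on the hues printed above (a quick sketch, independent of the repo): matplotlib places the primaries at hues 0, 1/3, 2/3, and the secondaries produced by the column roll (magenta, yellow, cyan) at 5/6, 1/6, 1/2.

import numpy as np
from matplotlib.colors import rgb_to_hsv

primaries = np.eye(3)  # R, G, B
secondaries = np.eye(3) + np.eye(3)[:, [1, 2, 0]]  # magenta, yellow, cyan
hues = rgb_to_hsv(np.vstack([primaries, secondaries]))[:, 0]
print(hues)  # approx [0. 0.333 0.667 0.833 0.167 0.5]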

86
hsv1.txt Normal file

@@ -0,0 +1,86 @@
# lightning.pytorch==2.1.3
seed_everything: 1387
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: null
logger: null
callbacks:
- class_path: callbacks.SaveImageCallback
init_args:
save_interval: 0
final_dir: out
fast_dev_run: false
max_epochs: 10
min_epochs: 10
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: 50
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: null
log_every_n_steps: 3
enable_checkpointing: null
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: true
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
transform: tanh
width: 128
depth: 4
bias: true
alpha: 0.0
data:
val_size: 10000
train_size: 10000
batch_size: 256
num_workers: 3
ckpt_path: null
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.001
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 0.0
amsgrad: false
foreach: null
maximize: false
capturable: false
differentiable: false
fused: null
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: hp_metric
mode: min
factor: 0.05
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 10
min_lr: 0.0
eps: 1.0e-08
verbose: true

86
hsv2.txt Normal file

@@ -0,0 +1,86 @@
# lightning.pytorch==2.1.3
seed_everything: 31
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: null
logger: null
callbacks:
- class_path: callbacks.SaveImageCallback
init_args:
save_interval: 0
final_dir: out
fast_dev_run: false
max_epochs: 10
min_epochs: 10
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: 50
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: null
log_every_n_steps: 3
enable_checkpointing: null
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: true
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
transform: tanh
width: 256
depth: 8
bias: true
alpha: 0.0
data:
val_size: 10000
train_size: 10000
batch_size: 256
num_workers: 3
ckpt_path: null
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 0.01
amsgrad: false
maximize: false
foreach: null
capturable: false
differentiable: false
fused: null
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: hp_metric
mode: min
factor: 0.05
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 10
min_lr: 0.0
eps: 1.0e-08
verbose: true

86
hsv3.txt Normal file

@@ -0,0 +1,86 @@
# lightning.pytorch==2.1.3
seed_everything: 1009
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: null
logger: null
callbacks:
- class_path: callbacks.SaveImageCallback
init_args:
save_interval: 0
final_dir: out
fast_dev_run: false
max_epochs: 10
min_epochs: 10
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: 50
limit_test_batches: null
limit_predict_batches: null
overfit_batches: 0.0
val_check_interval: null
check_val_every_n_epoch: 1
num_sanity_val_steps: null
log_every_n_steps: 3
enable_checkpointing: null
enable_progress_bar: null
enable_model_summary: null
accumulate_grad_batches: 1
gradient_clip_val: null
gradient_clip_algorithm: null
deterministic: null
benchmark: null
inference_mode: true
use_distributed_sampler: true
profiler: null
detect_anomaly: false
barebones: false
plugins: null
sync_batchnorm: false
reload_dataloaders_every_n_epochs: 0
default_root_dir: null
model:
transform: tanh
width: 512
depth: 2
bias: true
alpha: 0.75
data:
val_size: 10000
train_size: 10000
batch_size: 256
num_workers: 3
ckpt_path: null
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 0.01
amsgrad: false
maximize: false
foreach: null
capturable: false
differentiable: false
fused: null
lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
monitor: hp_metric
mode: min
factor: 0.05
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 10
min_lr: 0.0
eps: 1.0e-08
verbose: true

41
index.html Normal file

@@ -0,0 +1,41 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Discover the Rainbow</title>
<style>
body {
text-align: center;
font-family: Arial, sans-serif;
}
h1 {
margin-top: 50px;
}
.links {
margin-top: 20px;
}
.links a {
display: block;
margin: 10px 0;
font-size: 18px;
color: blue;
text-decoration: none;
}
.links a:hover {
text-decoration: underline;
}
</style>
</head>
<body>
<h1>Discover the Rainbow</h1>
<div class="links">
<a href="./out">Current</a>
<a href="./out1">Iteration 1</a>
<a href="./out2">Iteration 2</a>
<a href="./out3">Iteration 3</a>
<a href="./out4">Iteration 4 (good refactor, searching for mix of supervision)</a>
<a href="./out">Iteration 5 (all supervised)</a>
</div>
</body>
</html>


@@ -1,6 +1,6 @@
import torch
from utils import PURE_RGB
from utils import RGBMYC_ANCHOR
# def smoothness_loss(outputs):
# # Sort outputs for smoothness calculation
@@ -17,55 +17,61 @@ from utils import PURE_RGB
# return smoothness_loss
def preservation_loss(inputs, outputs):
# Distance Preservation Component
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
# Distance Preservation Component (or scaled euclidean if given targets)
# Encourages the model to keep relative distances from the RGB space in the transformed space
if target_inputs is None:
target_inputs = inputs
else:
assert target_outputs is not None
if target_outputs is None:
target_outputs = outputs
# Calculate RGB Norm
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
# max_rgb_distance = 1
rgb_norm = (
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
/ max_rgb_distance
)
rgb_norm = (
rgb_norm % 1
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
# connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
# rgb_norm = rgb_norm % 1 # i think this is why yellow and blue end up adjacent.
# yes it connects black and white, but also complimentary colors to primary
# print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm = torch.triu(
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
)
transformed_norm = circle_norm(outputs, target_outputs) * 2
diff = torch.abs(rgb_norm - transformed_norm)
# print(diff)
diff = torch.pow(rgb_norm - transformed_norm, 2)
# N = len(outputs)
# N = (N * (N - 1)) / 2
N = torch.count_nonzero(rgb_norm)
return torch.sum(diff) / N
return torch.mean(diff)
def circle_norm(vector, other_vector):
# Assumes vectors are of shape (N,1)
diff = torch.abs(vector - other_vector.T)
loss_a = torch.triu(diff)
loss_b = torch.triu(torch.abs(1 - diff))
loss = torch.minimum(loss_a, loss_b)
return loss
def separation_loss(red, green, blue):
# Separation Component
# TODO: remove
# Encourages the model to keep R, G, B values equally separated in the transformed space
red, green, blue = red % 1, green % 1, blue % 1
red_green_distance = torch.min(
torch.abs((red - green)), torch.abs((1 + red - green))
)
red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue)))
green_blue_distance = torch.min(
torch.abs((green - blue)), torch.abs((1 + green - blue))
)
# print(red_green_distance, red_blue_distance, green_blue_distance)
# we want these distances to be equal to one another
return (
torch.abs(red_green_distance - red_blue_distance)
+ torch.abs(red_green_distance - green_blue_distance)
+ torch.abs(red_blue_distance - green_blue_distance)
)
red_loss = torch.abs(0 - red)
green_loss = torch.abs(1 / 3 - green) / (2 / 3)
blue_loss = torch.abs(2 / 3 - blue) / (2 / 3)
return red_loss + green_loss + blue_loss
def calculate_separation_loss(model):
# TODO: remove
# Wrapper function to calculate separation loss
outputs = model(PURE_RGB)
outputs = model(RGBMYC_ANCHOR.to(model.device))
red, green, blue = outputs[0], outputs[1], outputs[2]
return separation_loss(red, green, blue)
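A toy check of the circular metric introduced above (reproducing circle_norm from this diff): on a circle of circumference 1, the points 0.05 and 0.95 are 0.10 apart via the wrap-around, not 0.90.

import torch

def circle_norm(vector, other_vector):
    # Pairwise gaps on the unit circle: the smaller of the direct
    # difference and the wrap-around difference, upper triangle only.
    diff = torch.abs(vector - other_vector.T)
    loss_a = torch.triu(diff)
    loss_b = torch.triu(torch.abs(1 - diff))
    return torch.minimum(loss_a, loss_b)

x = torch.tensor([[0.05], [0.95]])
print(circle_norm(x, x))  # off-diagonal entry: 0.10, not 0.90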

30
main.py

@@ -4,10 +4,10 @@ import random
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
from callbacks import SaveImageCallback
from dataloader import create_named_dataloader
from dataloader import create_named_dataloader as create_dataloader
from model import ColorTransformerModel
@@ -40,6 +40,7 @@ def parse_args():
default=3,
help="Number of workers for data loading",
)
parser.add_argument("--width", type=int, default=128, help="Max width of network.")
# Parse arguments
args = parser.parse_args()
@ -61,13 +62,13 @@ if __name__ == "__main__":
seed_everything(args.seed)
early_stop_callback = EarlyStopping(
monitor="hp_metric", # Metric to monitor
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
patience=24, # Number of epochs with no improvement after which training will be stopped
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
verbose=True,
)
# early_stop_callback = EarlyStopping(
# monitor="hp_metric", # Metric to monitor
# min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
# patience=5, # Number of epochs with no improvement after which training will be stopped
# mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
# verbose=True,
# )
save_img_callback = SaveImageCallback(
save_interval=0,
@ -76,8 +77,9 @@ if __name__ == "__main__":
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_named_dataloader(
N=0,
train_dataloader = create_dataloader(
# N=1e5,
skip=True,
batch_size=args.bs,
shuffle=True,
num_workers=args.num_workers,
@ -87,6 +89,10 @@ if __name__ == "__main__":
alpha=args.alpha,
learning_rate=args.lr,
batch_size=args.bs,
width=args.width,
bias=False,
transform="relu",
depth=1,
)
# Initialize model with parsed arguments
@ -95,7 +101,7 @@ if __name__ == "__main__":
# Initialize trainer with parsed arguments
trainer = pl.Trainer(
deterministic=True,
callbacks=[early_stop_callback, save_img_callback],
callbacks=[save_img_callback],
max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps,
)


@@ -1,20 +1,75 @@
lint:
black .
isort --profile=black .
flake8 --ignore E501,W503 .
isort --profile=black *.py
flake8 --ignore E501,W503,E203 *.py
test:
python main.py --alpha 2 --lr 2e-4 --max_epochs 200
# python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
python newmain.py fit \
--seed_everything 21 \
--data.batch_size 256 \
--data.train_size 0 \
--data.val_size 100000 \
--model.alpha 0 \
--model.width 2048 \
--trainer.fast_dev_run 1 \
--trainer.min_epochs 1 \
--trainer.max_epochs 10 \
--trainer.log_every_n_steps 5 \
--trainer.check_val_every_n_epoch 1 \
--trainer.callbacks callbacks.SaveImageCallback \
--trainer.callbacks.init_args.final_dir out \
--trainer.callbacks.init_args.save_interval 0 \
--optimizer torch.optim.Adam \
--optimizer.init_args.lr 0.01 \
--lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
--lr_scheduler.init_args.patience 5 \
--lr_scheduler.init_args.cooldown 10 \
--lr_scheduler.init_args.factor 0.05 \
--lr_scheduler.init_args.monitor hp_metric \
--lr_scheduler.init_args.verbose true \
--print_config
search:
python search.py
help:
# python newmain.py fit --help --trainer.callbacks.help
# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
python newmain.py fit --help
search: lint
python newsearch.py
hsv:
python hsv.py
# TODO: replace this with what we used in day-in-the-life
animate:
ffmpeg -i lightning_logs/version_258/e%04d.png \
-c:v libx264 \
-vf "fps=12,format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" \
~/animated.mp4
umap:
for seed in `seq 0 100`; do \
python scripts/sortcolor.py -s umap --dpi 300 --seed $$seed ; \
done
sort_umap:
python scripts/sortcolor.py -s umap --dpi 300 --seed 21
parallel_umap:
parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 1000)
sort_lex:
python scripts/sortcolor.py -s lex --dpi 300
sort_hsv:
python scripts/sortcolor.py -s hsv --dpi 300
clean:
rm -rf lightning_logs/*
rm out/*.png
rm -rf lightning_logs
rm -rf .lr_find_*.ckpt
rm -f out/*.png out/*.txt
rm -rf __pycache__/
cp hsv.png out/

173
model.py

@@ -1,106 +1,57 @@
import pytorch_lightning as pl
import lightning as L
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Model layers
# self.layers = nn.Sequential(
# nn.Linear(5, 128, bias=False),
# nn.Linear(128, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 64, bias=False),
# nn.Linear(64, 128, bias=False),
# nn.Linear(128, 256, bias=False),
# nn.Linear(256, 128, bias=False),
# nn.ReLU(),
# nn.Linear(128, 1, bias=False),
# )
# def forward(self, x):
# x = self.layers(x)
# x = (torch.sin(x) + 1) / 2
# return x
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Embedding layer to expand the input dimensions
# self.embedding = nn.Linear(3, 128, bias=False)
# # Transformer encoder-decoder
# encoder = nn.TransformerEncoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_encoder = nn.TransformerEncoder(
# encoder, num_layers=3
# )
# # lower dimensionality decoder
# decoder = nn.TransformerDecoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_decoder = nn.TransformerDecoder(
# decoder, num_layers=3
# )
# # Final linear layer to map back to 1D space
# self.final_layer = nn.Linear(128, 1, bias=False)
# def forward(self, x):
# # Embedding the input
# x = self.embedding(x)
# # Adjusting the shape for the transformer
# x = x.unsqueeze(1) # Adding a fake sequence dimension
# # Passing through the transformer
# x = self.transformer_encoder(x)
# # Passing through the decoder
# x = self.transformer_decoder(x, memory=x)
# # Reshape back to original shape
# x = x.squeeze(1)
# # Final linear layer
# x = self.final_layer(x)
# # Apply sigmoid activation to ensure output is in (0, 1)
# # x = torch.sigmoid(x)
# x = (torch.sin(x) + 1) / 2
# return x
from losses import circle_norm, preservation_loss # noqa: F401
from utils import RGBMYC_ANCHOR
class ColorTransformerModel(pl.LightningModule):
def __init__(self, params):
class ColorTransformerModel(L.LightningModule):
def __init__(
self,
transform: str = "relu",
width: int = 128,
depth: int = 1,
bias: bool = False,
alpha: float = 0,
lr: float = 0.01,
loop: bool = False,
dropout=0.5,
):
super().__init__()
self.save_hyperparameters(params)
self.save_hyperparameters()
if self.hparams.transform.lower() == "tanh":
t = nn.Tanh
elif self.hparams.transform.lower() == "relu":
t = nn.ReLU
w = self.hparams.width
d = self.hparams.depth
bias = self.hparams.bias
if self.hparams.loop:
midlayers = []
midlayers += [nn.Linear(w, w, bias=bias), t()] * d
else:
midlayers = sum(
[
[nn.Linear(w, w, bias=bias), nn.Dropout(self.dropout), t()]
for _ in range(d)
],
[],
)
# Neural network layers
self.network = nn.Sequential(
nn.Linear(3, 16),
nn.ReLU(),
nn.Linear(16, 16),
nn.ReLU(),
nn.Linear(16, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Linear(3, w, bias=bias),
t(),
*midlayers,
nn.Linear(w, 3, bias=bias),
t(),
nn.Linear(3, 1, bias=bias),
)
def forward(self, x):
# Pass the input through the network
x = self.network(x)
# Circular mapping
# x = (torch.sin(x) + 1) / 2
@@ -110,25 +61,51 @@ class ColorTransformerModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs)
s_loss = calculate_separation_loss(model=self)
rgb_tensor = RGBMYC_ANCHOR.to(self.device) # noqa: F841
p_loss = preservation_loss(
inputs,
outputs,
# target_inputs=rgb_tensor,
# target_outputs=self.forward(rgb_tensor),
)
alpha = self.hparams.alpha
loss = p_loss + alpha * s_loss
self.log("hp_metric", loss)
self.log("p_loss", p_loss)
self.log("s_loss", s_loss)
# N = len(outputs)
# distance = circle_norm(outputs, labels).mean()
distance = torch.norm(outputs - labels).mean()
# Backprop with this:
loss = (1 - alpha) * p_loss + alpha * distance
# p_loss is unsupervised (preserve relative distances - either in-batch or to-target)
# distance is supervised.
self.log("hp_metric", distance)
# Log all losses individually
self.log("train_pres", p_loss)
self.log("train_mse", distance)
self.log("train_loss", loss)
return loss
def validation_step(self, batch):
inputs, labels = batch # these are true HSV labels - no learning allowed.
outputs = self.forward(inputs)
# distance = torch.minimum(
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
# )
distance = torch.norm(outputs - labels)
mean_loss = torch.mean(distance)
max_loss = torch.max(distance)
self.log("val_mse", mean_loss)
self.log("val_max", max_loss)
return mean_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(
self.parameters(),
lr=self.hparams.learning_rate,
lr=self.hparams.lr,
)
lr_scheduler = ReduceLROnPlateau(
optimizer, mode="min", factor=0.05, patience=10, cooldown=20, verbose=True
optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
)
return {
"optimizer": optimizer,

15
newmain.py Normal file

@@ -0,0 +1,15 @@
from lightning.pytorch.cli import LightningCLI
# from callbacks import SaveImageCallback
from datamodule import ColorDataModule
from model import ColorTransformerModel
def cli_main():
cli = LightningCLI(ColorTransformerModel, ColorDataModule) # noqa: F841
# note: don't call fit!!
if __name__ == "__main__":
cli_main()
# note: it is good practice to implement the CLI in a function and call it in the main if block

113
newsearch.py Normal file

@@ -0,0 +1,113 @@
import subprocess
import sys
from random import sample, seed
import numpy as np # noqa: F401
from lightning_sdk import Machine, Studio # noqa: F401
# consistency of randomly sampled experiments.
seed(19920921)
NUM_JOBS = 100
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
# studio = Studio()
# use the jobs plugin
# studio.install_plugin("jobs")
# job_plugin = studio.installed_plugins["jobs"]
# do a sweep over learning rates
# Define the ranges or sets of values for each hyperparameter
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-3]
# learning_rate_values = [5e-4]
# alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised.
alpha_values = [0]
# widths = [2**k for k in range(4, 13)]
# depths = [1, 2, 4, 8, 16]
widths, depths = [512], [4]
batch_size_values = [256]
max_epochs_values = [100]
seeds = list(range(21, 1992))
optimizers = [
# "Adagrad",
"Adam",
# "SGD",
# "AdamW",
# "LBFGS",
# "RAdam",
# "RMSprop",
# "Adadelta",
]
# Generate all possible combinations of hyperparameters
all_params = [
(alpha, lr, bs, me, s, w, d, opt)
for alpha in alpha_values
for lr in learning_rate_values
for bs in batch_size_values
for me in max_epochs_values
for s in seeds
for w in widths
for d in depths
for opt in optimizers
]
# perform random search with a limit
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))
# --trainer.callbacks+ lightning.pytorch.callbacks.EarlyStopping \
# --trainer.callbacks.init_args.monitor hp_metric \
for idx, params in enumerate(search_params):
a, lr, bs, me, s, w, d, opt = params
# cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
cmd = f"""
python newmain.py fit \
--seed_everything {s} \
--data.batch_size {bs} \
--data.train_size 0 \
--data.val_size 10000 \
--model.alpha {a} \
--model.width {w} \
--model.depth {d} \
--model.bias true \
--model.loop true \
--model.transform tanh \
--trainer.min_epochs 10 \
--trainer.max_epochs {me} \
--trainer.log_every_n_steps 3 \
--trainer.check_val_every_n_epoch 1 \
--trainer.limit_val_batches 50 \
--trainer.callbacks callbacks.SaveImageCallback \
--trainer.callbacks.init_args.final_dir out \
--trainer.callbacks.init_args.save_interval 0 \
--optimizer torch.optim.{opt} \
--optimizer.init_args.lr {lr} \
--trainer.callbacks+ lightning.pytorch.callbacks.LearningRateFinder
# --lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
# --lr_scheduler.init_args.monitor hp_metric \
# --lr_scheduler.init_args.factor 0.05 \
# --lr_scheduler.init_args.patience 5 \
# --lr_scheduler.init_args.cooldown 10 \
# --lr_scheduler.init_args.verbose true
"""
# job_name = f"color2_{bs}_{a}_{lr:2.2e}"
# job_plugin.run(cmd, machine=Machine.T4, name=job_name)
print(f"Running {params}: {cmd}")
try:
# Run the command and wait for it to complete
# subprocess.run(test_cmd, shell=True, check=True)
subprocess.run(cmd, shell=True, check=True)
except KeyboardInterrupt:
print("Interrupted by user")
sys.exit(1)
# except subprocess.CalledProcessError:
# pass


@@ -72,8 +72,13 @@
function loadImages() {
var gallery = document.getElementById('gallery');
for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100
let imageName = 'v' + i + '.png';
for (var i = 0; i < 200; i++) { // Changed from i <= 100 to i < 100
let imageName;
if (i == -21) {
imageName = 'hsv.png';
} else {
imageName = 'v' + i + '.png';
}
let img = document.createElement('img');
img.src = imageName;
img.onerror = function () { this.style.display = 'none'; };

210
requirements.txt Normal file

@@ -0,0 +1,210 @@
absl-py==2.0.0
aiobotocore==2.11.0
aiofiles==22.1.0
aiohttp==3.9.1
aioitertools==0.11.0
aiosignal==1.3.1
aiosqlite==0.19.0
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
backoff==2.2.1
beautifulsoup4==4.12.2
bitsandbytes==0.42.0
black==23.12.1
bleach==6.1.0
blessed==1.20.0
boto3==1.34.17
botocore==1.34.17
cachetools==5.3.2
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.1
contourpy==1.2.0
croniter==1.4.1
cycler==0.12.1
dateutils==0.6.12
debugpy==1.8.0
decorator==5.1.1
deepdiff==6.7.1
defusedxml==0.7.1
docker==6.1.3
docstring-parser==0.15
editor==1.6.5
entrypoints==0.4
exceptiongroup==1.2.0
executing==2.0.1
fastapi==0.109.0
fastjsonschema==2.19.1
filelock==3.13.1
flake8==7.0.0
fonttools==4.47.2
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2023.12.2
google-auth==2.26.2
google-auth-oauthlib==1.2.0
grpcio==1.60.0
h11==0.14.0
hydra-core==1.3.2
idna==3.6
importlib-metadata==7.0.1
importlib-resources==6.1.1
iniconfig==2.0.0
inquirer==3.2.1
ipykernel==6.26.0
ipython==8.17.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
isort==5.13.2
jedi==0.19.1
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
json5==0.9.14
jsonargparse==4.27.2
jsonpointer==2.4
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-ydoc==0.2.5
jupyter_client==7.4.9
jupyter_core==5.7.1
jupyter_server==2.12.4
jupyter_server_fileid==0.9.1
jupyter_server_terminals==0.5.1
jupyter_server_ydoc==0.6.1
jupyterlab==3.6.1
jupyterlab-widgets==3.0.9
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
kiwisolver==1.4.5
lightning==2.1.3
lightning-api-access==0.0.5
lightning-cloud==0.5.57
lightning-fabric==2.1.3
lightning-utilities==0.10.0
lightning_sdk==0.0.13a0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.6
mccabe==0.7.0
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
nbclassic==1.0.0
nbclient==0.9.0
nbconvert==7.14.1
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2.1
notebook==6.5.6
notebook_shim==0.2.3
numpy==1.26.2
oauthlib==3.2.2
omegaconf==2.3.0
ordered-set==4.1.0
overrides==7.4.0
packaging==23.2
pandas==2.1.4
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.12.1
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.1.0
pluggy==1.4.0
prometheus-client==0.19.0
prompt-toolkit==3.0.43
protobuf==4.23.4
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycodestyle==2.11.1
pycparser==2.21
pydantic==2.5.3
pydantic_core==2.14.6
pyflakes==3.2.0
Pygments==2.17.2
PyJWT==2.8.0
pyparsing==3.1.1
pytest==7.4.4
python-dateutil==2.8.2
python-json-logger==2.0.7
python-multipart==0.0.6
pytorch-lightning==2.1.2
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==24.0.1
readchar==4.0.5
redis==5.0.1
referencing==0.32.1
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rpds-py==0.17.1
rsa==4.9
runs==1.2.0
s3fs==2023.12.2
s3transfer==0.10.0
scikit-learn==1.3.2
scipy==1.11.4
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.35.1
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
terminado==0.18.0
threadpoolctl==3.2.0
tinycss2==1.2.1
tomli==2.0.1
torch==2.1.1+cu118
torchmetrics==1.3.0.post0
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
triton==2.1.0
types-python-dateutil==2.8.19.20240106
typeshed-client==2.4.0
typing_extensions==4.9.0
tzdata==2023.4
uri-template==1.3.0
urllib3==2.0.7
uvicorn==0.25.0
watermark==2.4.3
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
websockets==11.0.3
Werkzeug==3.0.1
widgetsnbextension==4.0.9
wrapt==1.16.0
xmod==1.8.1
y-py==0.6.2
yarl==1.9.4
ypy-websocket==0.8.4
zipp==3.17.0

64
scripts/place.sh Normal file

@@ -0,0 +1,64 @@
#!/bin/bash
# Constants
# SEED=20230920
SEED=0
# INPUT_FILE="${DIR}/arrangement_$SEED.txt"
DIR=/teamspace/studios/this_studio/out_sortcolors
INPUT_FILE="${DIR}/arrangement_grid.txt"
# OUTPUT_IMAGE="${DIR}/circle_composite_$SEED.png"
TYPE=circle
OUTPUT_IMAGE="${DIR}/${TYPE}_composite_grid.png"
DPI=150
CANVAS_SIZE=$((72*$DPI)) # 72 inches
CIRCLE_IMAGE="${DIR}/hsv/sorted_colors_circle.png"
KIND=umap
identify $CIRCLE_IMAGE
# PREP
echo "Building ops"
# Build the composite operations string
composite_ops=""
idx=1
while IFS=, read -r x y; do
# Translate so that (0,0) becomes the center of the canvas
fx=$(echo "$x*$DPI + $CANVAS_SIZE/2" | bc -l)
fy=$(echo "$CANVAS_SIZE/2 - $y*$DPI" | bc -l)
# Convert the final float values to integers
ix=$(printf "%.0f" "$fx")
iy=$(printf "%.0f" "$fy")
if [[ idx -eq 42 ]]; then
CIRCLE_IMAGE="${DIR}/hsv/sorted_colors_${TYPE}.png"
else
idx_str=$(printf "%04d" "$idx")
CIRCLE_IMAGE="${DIR}/${KIND}/${idx_str}_sorted_colors_${TYPE}.png"
# CIRCLE_IMAGE="${DIR}/hsv_sorted_colors_${TYPE}.png"
fi
# Add to the composite operations string
composite_ops="$composite_ops \( $CIRCLE_IMAGE \) -resize 50% -compose Over -geometry +$ix+$iy -composite"
idx=$((idx+1))
done < $INPUT_FILE
# COMPOSITE
echo "Compositing"
# Use convert with the built composite operations string
eval "convert -units PixelsPerInch \
-size ${CANVAS_SIZE}x${CANVAS_SIZE} xc:white \
$composite_ops \
-density $DPI \
$OUTPUT_IMAGE"
echo "Saved $OUTPUT_IMAGE"
# DEBUG
# eval "convert -units PixelsPerInch \
# $OUTPUT_IMAGE \
# \( ${DIR}/arrangement_$SEED.png -evaluate Multiply 0.5 \) \
# -gravity center -composite \
# -density $DPI \
# /tmp/debug.png"
# open /tmp/debug.png


@@ -0,0 +1,5 @@
pip install \
--extra-index-url=https://pypi.nvidia.com \
cudf-cu12==23.12.* cuml-cu12==23.12.* \
dask-cudf-cu12==23.12.* pylibraft-cu12==23.12.* \
raft-dask-cu12==23.12.*


@@ -0,0 +1,6 @@
pip install \
--extra-index-url=https://pypi.nvidia.com \
cudf-cu12==23.12.* dask-cudf-cu12==23.12.* cuml-cu12==23.12.* \
cugraph-cu12==23.12.* cuspatial-cu12==23.12.* cuproj-cu12==23.12.* \
cuxfilter-cu12==23.12.* cucim-cu12==23.12.* pylibraft-cu12==23.12.* \
raft-dask-cu12==23.12.*

148
scripts/scatter.py Normal file

@@ -0,0 +1,148 @@
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
def get_quadrant(x, y):
"""Return the quadrant of a given point."""
if x >= 0 and y >= 0:
return 1
elif x < 0 and y >= 0:
return 2
elif x < 0 and y < 0:
return 3
else:
return 4
def is_overlapping(x, y, existing_points, radius):
"""Check if a point overlaps with existing points, crosses a quadrant or goes beyond the quadrant lines."""
current_quadrant = get_quadrant(x, y)
buffer = 0.25
# Check if circle touches/crosses the x=0 or y=0 lines.
if (
abs(x) + radius > buffer
and abs(x) - radius < buffer
or abs(y) + radius > buffer
and abs(y) - radius < buffer
):
return True
for ex, ey in existing_points:
# Check overlap with existing points
if np.sqrt((x - ex) ** 2 + (y - ey) ** 2) <= 2 * radius + 0.5:
return True
# Check if the circle touches another quadrant
if (
get_quadrant(ex, ey) != current_quadrant
and np.sqrt((x - ex) ** 2 + (y - ey) ** 2) <= radius
):
return True
return False
DIR = "/teamspace/studios/this_studio/out_sortcolors"
N = 100
DPI = 300
SIZE = 72 # canvas size?
radius = 3
variance = 72 # Adjust variance as needed
Path(DIR).mkdir(exist_ok=True, parents=True)
with open(f"{DIR}/arrangement_grid.txt", "w") as f:
radius = 3
half = SIZE / 2 - 4
# make points an equispaced grid of 10 x 10 ranging from [-35, 35]
interval = (half - (-half)) / 9 # 10 points, so 9 intervals
# Generate the grid points
points = [
(x, y)
for x in np.arange(-half, half + 0.1, interval)
for y in np.arange(-half, half + 0.1, interval)
]
for x, y in points:
f.write(f"{x-radius}, {y+radius}\n")
# f.write(f"{x}, {y}\n")
print("wrote grid")
for seed in range(11, 22):
try:
np.random.seed(seed)
# To store plotted points
points = []
max_iterations = int(1e7)
iterations = 0
# Generate points
for k in range(N):
while True:
# x, y = np.random.normal(0, variance), np.random.normal(0, variance)
random_angle = np.random.uniform(0, 2 * np.pi)
# random_radius = np.random.uniform(0.25+radius, SIZE/2 - radius - 0.25)
random_radius = abs(np.random.normal(0, variance))
x = random_radius * np.cos(random_angle)
y = random_radius * np.sin(random_angle)
iterations += 1
if not is_overlapping(x, y, points, radius):
if max(abs(x), abs(y)) + radius < SIZE / 2 - 0.25:
points.append((x, y))
break
if iterations > max_iterations:
raise ValueError(f"Too many iterations: {k} points")
print(f"{k}: ({x}, {y}) @ {iterations:09d}")
# Create plot with circles
fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
for x, y in points:
circle = patches.Circle((x, y), radius, color="black")
ax.add_patch(circle)
# Draw the standard quadrants
ax.axhline(0, color="black", linewidth=0.5)
ax.axvline(0, color="black", linewidth=0.5)
# Use square axis and set limits
lim_x = (
max(
abs(max(points, key=lambda t: t[0])[0]),
abs(min(points, key=lambda t: t[0])[0]),
)
+ radius
)
lim_y = (
max(
abs(max(points, key=lambda t: t[1])[1]),
abs(min(points, key=lambda t: t[1])[1]),
)
+ radius
)
lim_x = lim_y = SIZE / 2
ax.axis("off")
ax.set_aspect("equal")
ax.set_xlim(-lim_x, lim_x)
ax.set_ylim(-lim_y, lim_y)
# Save and show
fig.tight_layout(pad=0)
plt.savefig(
f"{DIR}/arrangement_{seed}.png", dpi=DPI, bbox_inches="tight", pad_inches=0
)
# plt.show()
# also save x/y coords as text file
with open(f"{DIR}/arrangement_{seed}.txt", "w") as f:
for x, y in points:
f.write(f"{x-radius}, {y+radius}\n")
# f.write(f"{x}, {y}\n")
except ValueError as e:
print(f"{seed}: {e}")
except AssertionError as e:
print(f"{seed}: {e}")

379
scripts/sortcolor.py Normal file

@@ -0,0 +1,379 @@
import argparse
from pathlib import Path
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from hilbertcurve.hilbertcurve import HilbertCurve
# Extract XKCD colors
colors = list(mcolors.XKCD_COLORS.keys())
rgb_values = [mcolors.to_rgb(mcolors.XKCD_COLORS[color]) for color in colors]
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--sort-by", type=str, default="hsv", help="kind of sorting")
parser.add_argument("--seed", type=int, default=21, help="seed for UMAP")
parser.add_argument("--dpi", type=int, default=100, help="dpi for saving")
parser.add_argument("--size", type=float, default=6.0, help="size of figure")
parser.add_argument(
"--fontsize",
type=float,
default=0,
help="fontsize of annotation (default: 0 = None)",
)
parser.add_argument(
"--radius", type=float, default=1 / 3, help="inner radius of circle"
)
args = parser.parse_args()
KIND = args.sort_by
SEED = args.seed
DPI = args.dpi
SIZE = args.size
FONTSIZE = args.fontsize
INNER_RADIUS = args.radius
DIR = "/teamspace/studios/this_studio/out_sortcolors"
prefix = ""
if KIND == "umap":
prefix = f"{SEED:04d}_"
FDIR = f"{DIR}/{KIND}"
Path(FDIR).mkdir(exist_ok=True, parents=True)
fname = f"{FDIR}/{prefix}sorted_colors_circle.png"
def peano_curve(n):
"""
Generate Peano curve coordinates for a given order `n`.
"""
if n == 0:
return [(0, 0)]
prev_curve = peano_curve(n - 1)
max_coord = 3 ** (n - 1)
# Define the transformations for the Peano curve's 9 segments
transforms = [
lambda x, y: (x, y), # Bottom-left square
lambda x, y: (x + max_coord, y), # Bottom-middle square
lambda x, y: (x + 2 * max_coord, y), # Bottom-right square
lambda x, y: (x + 2 * max_coord, y + max_coord), # Middle-right square
lambda x, y: (
x + max_coord,
y + max_coord,
), # Center square (traversed in reverse)
lambda x, y: (x, y + max_coord), # Middle-left square
lambda x, y: (x, y + 2 * max_coord), # Top-left square
lambda x, y: (x + max_coord, y + 2 * max_coord), # Top-middle square
lambda x, y: (x + 2 * max_coord, y + 2 * max_coord), # Top-right square
]
curve = []
for transform in transforms:
segment = [transform(x, y) for x, y in prev_curve]
# Reverse the traversal for the center square
if transform == transforms[4]:
segment = segment[::-1]
curve += segment
return curve
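A short check (illustrative, not in the file): an order-n curve from this construction visits 9**n distinct cells.

# Illustrative check: the order-n Peano curve covers 9**n distinct points
for n in range(4):
    coords = peano_curve(n)
    assert len(coords) == 9**n == len(set(coords))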
if KIND in ("lex", "alpha", "abc"):
preds = np.array(colors)
elif KIND == "umap":
# from umap import UMAP  # CPU implementation
from cuml import UMAP  # GPU drop-in replacement (cuML)
# Use UMAP to create a 1D representation
reducer = UMAP(
n_components=1,
n_neighbors=250,
min_dist=1e-2,
metric="euclidean",
random_state=SEED,
negative_sample_rate=2,
)
embedding = reducer.fit_transform(np.array(rgb_values))
# Sort colors by the 1D representation
preds = embedding[:, 0]
del reducer, embedding
elif KIND in ("cielab", "lab", "ciede2000"):
from skimage.color import deltaE_ciede2000, rgb2lab
# CIELAB
# Convert RGB values to CIELAB; rgb2lab expects an image-shaped array,
# so the list is wrapped to give shape (1, n_colors, 3)
lab_values = rgb2lab([rgb_values])
# The CIEDE2000 distance-to-reference idea below was abandoned ("awful" results):
# reference_color = lab_values[0]
# distances = [deltaE_ciede2000(reference_color, color) for color in lab_values]
# preds = np.array(distances).flatten()  # awful
lab_values_flat = lab_values.reshape(-1, 3)
# Sort colors by the L* (lightness) value, channel 0 in CIELAB
preds = lab_values_flat[:, 0]
elif KIND == "hsv":
from matplotlib.colors import rgb_to_hsv
# Convert RGB values to HSV
hsv_values = np.array([rgb_to_hsv(np.array(rgb)) for rgb in rgb_values])
# Sort colors based on the hue value
# 0 corresponds to the hue component
preds = hsv_values[:, 0]
else:
raise ValueError(f"Unknown kind: {KIND}")
sorted_indices = np.argsort(preds)
# Save the predictions to disk (applies to every kind, not just umap)
PDIR = f"scripts/{KIND}"
Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{SEED:06d}.npy"
np.save(file_path, preds.ravel())
print(f"Predictions saved to {file_path}")
# Sort colors by the 1D representation
sorted_colors = [colors[i] for i in sorted_indices]
# # Display the sorted colors around the ring of a circle
# # Create a new figure for the circle visualization
# fig, ax = plt.subplots(figsize=(SIZE, SIZE))
# # Circle parameters
# center = (0, 0)
# radius = 1.0
# # Angle increment (in radians) based on the number of colors
# angle_increment = 2 * np.pi / len(sorted_colors)
# # roll the colors so that the first color is white
# reordered_sorted_colors = sorted_colors.copy()
# white_index = reordered_sorted_colors.index("xkcd:white")
# reordered_sorted_colors = reordered_sorted_colors[white_index:] + reordered_sorted_colors[:white_index]
# # Plot each color around the circle
# for i, color in enumerate(reordered_sorted_colors):
# # Compute start and end angles for the segment
# start_angle = i * angle_increment
# end_angle = (i + 1) * angle_increment
# # Create a wedge (segment of the circle) for each color
# wedge = patches.Wedge(
# center, radius, 90 + np.degrees(start_angle), 90 + np.degrees(end_angle), fc=color
# )
# ax.add_patch(wedge)
# # Overlay a white circle in the center
# inner_circle = patches.Circle(center, INNER_RADIUS, color="white")
# ax.add_patch(inner_circle)
# if INNER_RADIUS > 0.0:
# fcolor = "black"
# else:
# fcolor = "white"
# if FONTSIZE > 0.0:
# ax.annotate(f"{KIND.upper()}", center, ha="center", va="center", size=FONTSIZE, color=fcolor)
# # Set equal scaling and remove axis
# ax.set_aspect("equal")
# ax.axis("off")
# ax.set_ylim(-radius, radius)
# ax.set_xlim(-radius, radius)
# # Save and display the circle visualization
# prefix = ""
# if KIND == "umap":
# prefix = f"{SEED:02d}"
# fname = f"{DIR}/{prefix}{KIND}_sorted_colors_circle.png"
# fig.tight_layout(pad=0)
# fig.savefig(
# fname,
# dpi=DPI,
# transparent=True,
# pad_inches=0,
# bbox_inches="tight",
# )
# print(f"Saved {fname}")
# # plt.show()
def plot_preds(
preds,
rgb_values,
fname: str,
roll: bool = False,
dpi: int = 150,
inner_radius: float = 1 / 3,
figsize=(3, 3),
):
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :3]
if roll:
# find white in colors, put it first
white = np.array([1, 1, 1])
matches = np.where((colors == white).all(axis=1))[0]
# np.where returns a tuple, which is always truthy; test the match count instead
if len(matches) > 0:
colors = np.roll(colors, -matches[0], axis=0)
else:
print("no white, skipping")
N = len(colors)
# Create a plot with these hues in a circle
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
# Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
width = 2 * np.pi / N  # equal width for each wedge
for i in range(N):
ax.bar(
# 2 * np.pi * preds[i],
theta[i],
height=1,
width=width,
edgecolor=colors[i],
linewidth=0.25,
# facecolor=[rgb_values[i][1]]*3,
# facecolor=rgb_values[i],
facecolor=colors[i],
bottom=0.0,
zorder=1,
alpha=1,
align="edge",
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect("equal")
ax.axis("off")
radius = 1
ax.set_ylim(0, radius)
# Overlay a white circle in the center; ax.transData._b is the private
# cartesian transform behind the polar axes (a common matplotlib workaround)
circle = patches.Circle(
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
)
ax.add_patch(circle)
if FONTSIZE > 0.0:
center = (0, 0)
fcolor = "black"
ax.annotate(
f"{KIND.upper()}",
center,
ha="center",
va="center",
size=FONTSIZE,
color=fcolor,
)
fig.tight_layout(pad=0)
plt.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
plt.close()
plot_preds(
preds,
np.array(rgb_values),
fname,
roll=False,
dpi=DPI,
inner_radius=INNER_RADIUS,
figsize=(SIZE, SIZE),
)
print(f"saved {fname}")
HILBERT = False
if HILBERT:
# Create Hilbert curve
# Choose the order so the curve has at least as many cells as colors:
# a 2-D curve of order p visits 4**p cells, so p = ceil(log2(N) / 2)
hilbert_order = int(np.ceil(0.5 * np.log2(len(sorted_colors))))
hilbert_curve = HilbertCurve(hilbert_order, 2)
# Create an image for visualization
image_size = 2**hilbert_order
image = np.ones((image_size, image_size, 3))
for i, color in enumerate(sorted_colors):
# Convert linear index to Hilbert coordinates
coords = hilbert_curve.point_from_distance(i)
image[coords[1], coords[0]] = mcolors.to_rgb(color)
# annotation in upper right
# Display the image
fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
ax.imshow(image)
ax.annotate(
f"{KIND.upper()}",
(1.0, 1.0),
ha="right",
va="top",
size=FONTSIZE,
xycoords="axes fraction",
)
ax.axis("off")
ax.set_aspect("equal")
fig.tight_layout()
fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png"
fig.savefig(
fname,
dpi=DPI,
transparent=True,
# bbox_inches="tight",
# pad_inches=0
)
print(f"Saved {fname}")
# plt.show()
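The hilbertcurve package maps 1-D distances to 2-D cells and back; a tiny round-trip check (illustrative only, using the HilbertCurve import above):

# Round-trip: distance -> point -> distance on a small curve
hc = HilbertCurve(2, 2)  # order 2, 2 dimensions: 16 cells
pt = hc.point_from_distance(5)
assert hc.distance_from_point(pt) == 5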
# # Create peano curve
# order = 0
# while (3**order)**2 < len(sorted_colors):
# order += 1
# # Generate peano curve coordinates for the determined order
# curve_coords = peano_curve(order)
# unique_coords = set(curve_coords)
# print(f"Total coordinates: {len(curve_coords)}")
# print(f"Unique coordinates: {len(unique_coords)}")
# # If there are more points on the curve than colors, truncate the curve
# if len(curve_coords) > len(sorted_colors):
# print(f"Will be missing {len(curve_coords) - len(sorted_colors)} pixels, percentage: {100 * (len(curve_coords) - len(sorted_colors)) / len(curve_coords)}")
# curve_coords = curve_coords[:len(sorted_colors)]
# if len(curve_coords) < len(sorted_colors):
# raise ValueError("Not enough curve coordinates for the number of colors")
# # Create an image for visualization
# image_size = (3**order)
# image = np.ones((image_size, image_size, 3))
# for i, color in enumerate(sorted_colors):
# # Get the Moore curve coordinates for the current index
# coords = curve_coords[i]
# image[coords[1], coords[0]] = mcolors.to_rgb(color)
# # Display the image
# plt.figure(figsize=(8, 8))
# plt.imshow(image)
# plt.axis("off")
# plt.savefig(f"{DIR}/{KIND}_sorted_colors_moore.png", dpi=DPI)
# plt.show()

scripts/vips_composite.py Normal file (67 lines added)
@ -0,0 +1,67 @@
import glob
import os
import random

import numpy as np
import pyvips
def bg_fill(width: int, height: int, channels: int = 3, fill: int = 255):
"""Constant-fill background image."""
# cast to uint8: np.ones defaults to float64, which pyvips would treat as a double image
a = (np.ones(shape=(height, width, channels)) * fill).astype(np.uint8)
return pyvips.Image.new_from_array(a)
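Illustrative usage (a 4 x 4 single-band white tile):

# bg_fill returns a pyvips image backed by the NumPy array
bg = bg_fill(4, 4, channels=1, fill=255)
print(bg.width, bg.height, bg.bands)  # 4 4 1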
def create_grid_composite(
directory, k, spacing_inches, output_file="output.png", dpi=300
):
# Calculate spacing in pixels
spacing_pixels = int(spacing_inches * dpi)
# Glob for PNG images
png_files = glob.glob(os.path.join(directory, "*.png"))
# Randomly select k^2 images (random.sample raises ValueError if fewer are available)
selected_files = random.sample(png_files, k * k)
# Create an empty list to hold the images
images = [
pyvips.Image.new_from_file(file, access="sequential") for file in selected_files
]
# Calculate the size of the composite image
# (tile size hardcoded to 1800 px; pyvips images expose .width/.height, not .size)
# widths, heights = zip(*[(image.width, image.height) for image in images])
widths, heights = [1800], [1800]
max_width = max(widths)
max_height = max(heights)
# Calculate total size of the grid including spacing
total_width = k * max_width + (k + 1) * spacing_pixels
total_height = k * max_height + (k + 1) * spacing_pixels
# Create a blank image for the composite
# composite = pyvips.Image.black(total_width, total_height, bands=4)
composite = bg_fill(total_width, total_height, channels=1, fill=255)
# Place images into the composite
for i, image in enumerate(images):
row = i // k
col = i % k
x = col * (max_width + spacing_pixels) + spacing_pixels
y = row * (max_height + spacing_pixels) + spacing_pixels
composite = composite.insert(image.flatten(background=[255, 255, 255]), x, y)
# Save the composite image
composite.write_to_file(output_file)
x = pyvips.Image.thumbnail(output_file, 1080 * 4)
x.write_to_file("out_sm.png")
print(output_file)
if __name__ == "__main__":
# Example usage
directory = "/teamspace/studios/this_studio/out_sortcolors/umap/" # Change to your directory path
k = 10
spacing_inches = 0.5 # Half an inch between images
create_grid_composite(directory, k, spacing_inches)

@ -2,37 +2,41 @@ import subprocess
import sys
from random import sample
import numpy as np
import numpy as np # noqa: F401
from lightning_sdk import Machine, Studio # noqa: F401
NUM_JOBS = 100
# reference to the current studio
# if you run outside of Lightning, you can pass the Studio name
studio = Studio()
# studio = Studio()
# use the jobs plugin
studio.install_plugin("jobs")
job_plugin = studio.installed_plugins["jobs"]
# studio.install_plugin("jobs")
# job_plugin = studio.installed_plugins["jobs"]
# do a sweep over learning rates
# Define the ranges or sets of values for each hyperparameter
alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 41), 5))
learning_rate_values = [5e-4]
batch_size_values = [128]
max_epochs_values = [500]
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-2]
alpha_values = [0, 1, 2]
widths = [2**k for k in range(4, 15)]
# learning_rate_values = [5e-4]
batch_size_values = [256]
max_epochs_values = [100]
seeds = list(range(21, 1992))
# Generate all possible combinations of hyperparameters
all_params = [
(alpha, lr, bs, me, s)
(alpha, lr, bs, me, s, w)
for alpha in alpha_values
for lr in learning_rate_values
for bs in batch_size_values
for me in max_epochs_values
for s in seeds
for w in widths
]
@ -40,8 +44,8 @@ all_params = [
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))
for idx, params in enumerate(search_params):
a, lr, bs, me, s = params
cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s}"
a, lr, bs, me, s, w = params
cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
# job_name = f"color2_{bs}_{a}_{lr:2.2e}"
# job_plugin.run(cmd, machine=Machine.T4, name=job_name)
print(f"Running {params}: {cmd}")

@ -2,7 +2,7 @@ import matplotlib.colors as mcolors
import torch
def preprocess_data(data, skip=True):
def preprocess_data(data, skip: bool = True):
# Assuming 'data' is a tensor of shape [n_samples, 3]
if not skip:
# Compute argmin and argmax for each row
@ -34,6 +34,9 @@ def extract_colors():
return rgb_tensor, xkcd_color_names
PURE_RGB = preprocess_data(
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
RGBMYC_ANCHOR = preprocess_data(
torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
)
PURE_HSV = torch.tensor(
[[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32
)
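For reference, the new RGBMYC_ANCHOR concatenation evaluates to six anchors: red, green, blue, then magenta, yellow, cyan, matching the six hues in PURE_HSV (0, 1/3, 2/3, 5/6, 1/6, 1/2). A quick standalone check (not part of the diff):

import torch
# rows: red, green, blue, magenta, yellow, cyan
anchors = torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
assert anchors.tolist() == [
    [1, 0, 0], [0, 1, 0], [0, 0, 1],
    [1, 0, 1], [1, 1, 0], [0, 1, 1],
]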