Compare commits
81 Commits
ab6f1d3368
...
f0a4c940af
Author | SHA1 | Date | |
---|---|---|---|
|
f0a4c940af | ||
|
04ff370bbb | ||
|
90f08a5edb | ||
|
831f17a9b4 | ||
|
698f6f1b51 | ||
|
13dd754ba6 | ||
|
a97055f282 | ||
|
d1b963e05d | ||
|
f3bc1e90d3 | ||
|
b95cb8df3d | ||
|
527cc285f3 | ||
|
95a6adac91 | ||
|
447b773d03 | ||
|
00cba285c8 | ||
|
9953327d30 | ||
|
0ae8414481 | ||
|
2af491c324 | ||
|
5687f30818 | ||
|
5def982f12 | ||
|
a9e772f34e | ||
|
697030b7df | ||
|
f450611ce3 | ||
|
0c4990e98c | ||
|
05dd4e29ce | ||
|
d318480b7c | ||
|
1c116f3f12 | ||
|
c7ffd09fb4 | ||
|
4342a54cc8 | ||
|
248d1a72f9 | ||
|
9e4861a272 | ||
|
e5b6f287a3 | ||
|
865e7f5104 | ||
|
b6d9f94d8e | ||
|
953488be4c | ||
|
0e561aae4c | ||
|
30470f13bc | ||
|
70f56ff9f0 | ||
|
3adcc9779a | ||
|
a49f166252 | ||
|
a44580a15b | ||
|
eee8a8b0ba | ||
|
6260e7fdd3 | ||
|
467b3f7e57 | ||
|
1da8d3194a | ||
|
b5d9e725b3 | ||
|
721993d9e5 | ||
|
b9d334e49a | ||
|
07b4e548e2 | ||
|
948bc31861 | ||
|
fff2b88fa1 | ||
|
49e6260346 | ||
|
6d40d39097 | ||
|
3b700aee70 | ||
|
1e818aa977 | ||
|
1ea29ba11e | ||
|
e1ac3211b9 | ||
|
012c7b7c68 | ||
|
88c8cde9f6 | ||
|
b9b6ee7727 | ||
|
e61543299d | ||
|
2ec2d1f368 | ||
|
50628e594a | ||
|
a02a662b6f | ||
|
7ce24b0cd3 | ||
|
947a7d4a56 | ||
|
0366b5d0f1 | ||
|
70ecd7d7db | ||
|
62e0764a70 | ||
|
72a1ad2971 | ||
|
6899320927 | ||
|
999d73f7ab | ||
|
e709d8f34f | ||
|
686e096b97 | ||
|
1f7d4c1890 | ||
|
5ed305fe34 | ||
|
7461406cf2 | ||
|
0934ca0aed | ||
|
a8b723f021 | ||
|
1f96c65b21 | ||
|
b5bb3fe3df | ||
|
a8d62bfca0 |
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,3 +5,6 @@ out/
|
||||
.sw[opqr]
|
||||
*.tar.gz
|
||||
.pat
|
||||
out*
|
||||
.lr*
|
||||
*.npy
|
||||
|
@ -1,11 +1,11 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning import Callback
|
||||
|
||||
from check import create_circle
|
||||
|
||||
|
||||
class SaveImageCallback(pl.Callback):
|
||||
class SaveImageCallback(Callback):
|
||||
def __init__(self, save_interval=1, final_dir: str = None):
|
||||
self.save_interval = save_interval
|
||||
self.final_dir = final_dir
|
||||
|
90
check.py
90
check.py
@ -1,4 +1,7 @@
|
||||
# import matplotlib.patches as patches
|
||||
from typing import Union
|
||||
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -9,7 +12,7 @@ from model import ColorTransformerModel
|
||||
# import matplotlib.colors as mcolors
|
||||
|
||||
|
||||
def make_image(ckpt: str, fname: str, color=True):
|
||||
def make_image(ckpt: str, fname: str, color=True, **kwargs):
|
||||
M = ColorTransformerModel.load_from_checkpoint(ckpt)
|
||||
|
||||
# preds = M(rgb_tensor)
|
||||
@ -35,52 +38,87 @@ def make_image(ckpt: str, fname: str, color=True):
|
||||
ax.axis("off")
|
||||
# ax.axis("square")
|
||||
|
||||
plt.savefig(f"{fname}.png", dpi=300)
|
||||
plt.savefig(f"{fname}.png", **kwargs)
|
||||
|
||||
|
||||
def create_circle(ckpt: str, fname: str):
|
||||
def create_circle(
|
||||
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
|
||||
):
|
||||
if isinstance(ckpt, str):
|
||||
M = ColorTransformerModel.load_from_checkpoint(ckpt)
|
||||
else:
|
||||
M = ckpt
|
||||
|
||||
rgb_tensor, _ = extract_colors()
|
||||
preds = M(rgb_tensor)
|
||||
plot_preds(preds, fname=fname)
|
||||
xkcd_colors, _ = extract_colors()
|
||||
xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
|
||||
preds = M(xkcd_colors)
|
||||
rgb_array = xkcd_colors.detach().cpu().numpy()
|
||||
plot_preds(preds, rgb_array, fname=fname, **kwargs)
|
||||
|
||||
|
||||
def plot_preds(preds, fname: str, roll: bool = False):
|
||||
rgb_tensor, _ = extract_colors()
|
||||
rgb_values = rgb_tensor.detach().numpy()
|
||||
rgb_tensor = preprocess_data(rgb_tensor)
|
||||
|
||||
def plot_preds(
|
||||
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
|
||||
):
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().numpy()
|
||||
preds = preds.detach().cpu().numpy()
|
||||
sorted_inds = np.argsort(preds.ravel())
|
||||
colors = rgb_values[sorted_inds]
|
||||
colors = rgb_values[sorted_inds, :3]
|
||||
if roll:
|
||||
# find white in colors, put it first.
|
||||
white = np.array([1, 1, 1])
|
||||
white_idx = np.where((colors == white).all(axis=1))[0][0]
|
||||
colors = np.roll(colors, -white_idx, axis=0)
|
||||
white_idx = np.where((colors == white).all(axis=1))
|
||||
if white_idx:
|
||||
white_idx = white_idx[0][0]
|
||||
colors = np.roll(colors, -white_idx, axis=0)
|
||||
else:
|
||||
print("no white, skipping")
|
||||
# print(white_idx, colors[:2])
|
||||
|
||||
N = len(colors)
|
||||
# Create a plot with these hues in a circle
|
||||
fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True))
|
||||
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
|
||||
|
||||
# Each wedge in the circle
|
||||
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2
|
||||
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
|
||||
width = 2 * np.pi / (N) # equal size for each wedge
|
||||
|
||||
for i in range(N):
|
||||
ax.bar(theta[i], 1, width=width, color=colors[i], bottom=0.0)
|
||||
ax.bar(
|
||||
# 2 * np.pi * preds[i],
|
||||
theta[i],
|
||||
height=1,
|
||||
width=width,
|
||||
edgecolor=colors[i],
|
||||
linewidth=0.25,
|
||||
# facecolor=[rgb_values[i][1]]*3,
|
||||
# facecolor=rgb_values[i],
|
||||
facecolor=colors[i],
|
||||
bottom=0.0,
|
||||
zorder=1,
|
||||
alpha=1,
|
||||
align="edge",
|
||||
)
|
||||
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
fig.tight_layout()
|
||||
plt.savefig(f"{fname}.png", dpi=150)
|
||||
radius = 1
|
||||
ax.set_ylim(-radius, radius)
|
||||
ax.set_xlim(-radius, radius)
|
||||
|
||||
# Overlay white circle
|
||||
inner_radius = 1 / 3
|
||||
circle = patches.Circle(
|
||||
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
|
||||
fig.tight_layout(pad=0)
|
||||
|
||||
plt.savefig(
|
||||
f"{fname}.png", dpi=dpi, transparent=False, pad_inches=0, bbox_inches="tight"
|
||||
)
|
||||
plt.close()
|
||||
|
||||
|
||||
@ -91,18 +129,22 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# make the following accept a list of arguments
|
||||
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1])
|
||||
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
|
||||
parser.add_argument(
|
||||
"--dpi", type=int, default=150, help="Resolution for saved image."
|
||||
)
|
||||
parser.add_argument("--figsize", type=int, default=3, help="Figure size")
|
||||
args = parser.parse_args()
|
||||
versions = args.version
|
||||
for v in versions:
|
||||
name = f"out/v{v}"
|
||||
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
|
||||
ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
|
||||
ckpt_path = f"/teamspace/studios/colors-refactor-secondary/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
|
||||
ckpt = glob.glob(ckpt_path)
|
||||
if len(ckpt) > 0:
|
||||
ckpt = ckpt[-1]
|
||||
print(f"Generating image for checkpoint: {ckpt}")
|
||||
create_circle(ckpt, fname=name)
|
||||
create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2)
|
||||
else:
|
||||
print(f"No checkpoint found for version {v}")
|
||||
# make_image(ckpt, fname=name + "b", color=False)
|
||||
# make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 28 KiB |
235
color_poster.py
Normal file
235
color_poster.py
Normal file
@ -0,0 +1,235 @@
|
||||
from typing import List
|
||||
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# set font by default to courier new
|
||||
plt.rcParams["font.family"] = "PT Mono"
|
||||
# # sort by the proximity to the colors in viridis
|
||||
# # Importing necessary functions
|
||||
|
||||
# # Calculate the proximity of each color in XKCD_COLORS to the viridis colormap
|
||||
# def calculate_proximity_to_viridis(color):
|
||||
# rgb_triple = mcolors.to_rgb(mcolors.XKCD_COLORS[color])
|
||||
# distances = [(sum((a - b)**2 for a, b in zip(rgb_triple, viridis(i))), i) for i in range(256)]
|
||||
# _, closest_viridis_value = min(distances, key=lambda x: x[0]) # Find the viridis color with the minimum distance
|
||||
# return closest_viridis_value / 255 # Normalize to range (0, 1)
|
||||
|
||||
# # Calculate the proximity values for each color
|
||||
# proximity_values = {color: calculate_proximity_to_viridis(color) for color in colors}
|
||||
|
||||
# # Sort the colors based on their proximity values
|
||||
# sorted_colors = sorted(proximity_values.keys(), key=lambda x: proximity_values[x])
|
||||
|
||||
|
||||
def create_color_calibration_image(
|
||||
colors, ppi: List[int] = [100], index=0, solid_capstyle="butt", antialiased=True
|
||||
):
|
||||
first_color = colors[0]
|
||||
last_color = colors[-1]
|
||||
print(f"Processing color range: {first_color} to {last_color}")
|
||||
# Conversion factor: 1 inch = dpi pixels
|
||||
vert_space = 1 / 2 # inches
|
||||
fontsize = 12 # points
|
||||
|
||||
# Figure settings
|
||||
fig_width = 4.0 # inches
|
||||
fig_height = len(colors) * vert_space
|
||||
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
|
||||
|
||||
# plot a vertical black rectangle between x=3 and x=4
|
||||
ax.axvspan(0.25, 0.75, facecolor="black")
|
||||
# ax.axvspan(-0.325-0.175, -0.125, facecolor="black")
|
||||
# Loop through each color
|
||||
if index % 2 == 0:
|
||||
skip = -1
|
||||
else:
|
||||
skip = 1
|
||||
for idx, color in enumerate(colors[::skip]):
|
||||
# print(color)
|
||||
y_position = 0.5 - 0.125 + idx * vert_space # Offset each color by 0.3 inches
|
||||
|
||||
# Draw color name
|
||||
rgb_triple = mcolors.to_rgb(mcolors.XKCD_COLORS[color])
|
||||
# round to 4 decimal places
|
||||
rgb_triple = tuple(round(x, 4) for x in rgb_triple)
|
||||
# format as string with fixed decimal places
|
||||
rgb_triple = ", ".join([f"{x:1.4f}" for x in rgb_triple])
|
||||
hex_code = mcolors.to_hex(mcolors.XKCD_COLORS[color])
|
||||
ax.text(
|
||||
1.0,
|
||||
y_position,
|
||||
color.replace("xkcd:", ""),
|
||||
va="center",
|
||||
fontsize=fontsize,
|
||||
# bbox=dict(facecolor='gray', alpha=0.21),
|
||||
)
|
||||
ax.text(
|
||||
1.25,
|
||||
y_position - 1.5 * fontsize / 72,
|
||||
f"{hex_code}\n({rgb_triple})",
|
||||
va="center",
|
||||
fontsize=6,
|
||||
# bbox=dict(facecolor='gray', alpha=0.33),
|
||||
)
|
||||
# ax.text(
|
||||
# 1.125,
|
||||
# y_position - 2 * fontsize / 72,
|
||||
# f"{hex_code}",
|
||||
# va="center",
|
||||
# fontsize=fontsize * 0.6,
|
||||
# )
|
||||
|
||||
# Draw color square
|
||||
rect_height = 0.25
|
||||
rect_width = 0.25
|
||||
square_x_start = 0.75 - rect_width / 2 # Offset from the left
|
||||
square_y_start = y_position - rect_height # + 0.25
|
||||
ax.add_patch(
|
||||
plt.Rectangle(
|
||||
(square_x_start, square_y_start), rect_width, rect_height, fc=color
|
||||
)
|
||||
)
|
||||
|
||||
# Draw lines with varying stroke sizes
|
||||
line_x_start = 0 # Offset from the left
|
||||
line_length = 0.5
|
||||
line_y_start = y_position - rect_height - 0.075
|
||||
line_widths = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2.0, 3.0][::-1]
|
||||
for width in line_widths:
|
||||
ax.plot(
|
||||
[line_x_start, line_x_start + line_length],
|
||||
[line_y_start, line_y_start],
|
||||
linewidth=width,
|
||||
color=color,
|
||||
antialiased=antialiased,
|
||||
solid_capstyle=solid_capstyle,
|
||||
)
|
||||
line_y_start += 0.05
|
||||
|
||||
# # now repeat but vertical of height 1
|
||||
# line_x_start = 3.125 - 0.05 # Offset from the left
|
||||
# line_y_start = y_position + dpi / 960
|
||||
# for width in line_widths:
|
||||
# ax.plot(
|
||||
# [line_x_start, line_x_start],
|
||||
# [line_y_start, line_y_start - 0.5],
|
||||
# linewidth=width,
|
||||
# color=color,
|
||||
# antialiased=True,
|
||||
# )
|
||||
# ax.plot(
|
||||
# [0.5 + line_x_start, 0.5 + line_x_start],
|
||||
# [line_y_start, line_y_start - 0.5],
|
||||
# linewidth=width,
|
||||
# color=color,
|
||||
# antialiased=True,
|
||||
# )
|
||||
# line_x_start += 0.05
|
||||
|
||||
# Save the image
|
||||
# Remove axes
|
||||
ax.axis("off")
|
||||
|
||||
# ax.set_aspect("equal")
|
||||
# plt.tight_layout(pad=0)
|
||||
ax.set_ylim([0, fig_height])
|
||||
# ax.set_xlim([0, fig_width])
|
||||
ax.set_xlim([0, fig_width])
|
||||
# pad = 0.108
|
||||
pad = 0
|
||||
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
|
||||
output_paths = []
|
||||
for _dpi in ppi:
|
||||
_out_path = f"/tmp/color_calibration-{_dpi}_{index:02d}.png"
|
||||
plt.savefig(_out_path, pad_inches=pad, dpi=_dpi)
|
||||
output_paths.append(_out_path)
|
||||
# plt.show()
|
||||
plt.close()
|
||||
|
||||
return output_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--rows", type=int, default=73, help="Number of entries per column"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dir", type=str, default="/Volumes/TMP/tests", help="Directory to save images"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--kind",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["hsv", "lex", "lab", "umap"],
|
||||
help="Kinds of sorting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppi", action="append", type=int, default=[300], help="Pixels per inch"
|
||||
)
|
||||
parser.add_argument("--aliased", action="store_true", help="Disable antialiasing")
|
||||
parser.add_argument(
|
||||
"--capstyle", type=str, default="butt", help="Capstyle of lines"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
COLUMN_LENGTH = args.rows
|
||||
KINDS = args.kind
|
||||
PPIS = args.ppi
|
||||
DIR = args.dir
|
||||
ANTIALIASED = not args.aliased
|
||||
CAPSTYLE = args.capstyle
|
||||
|
||||
# COLUMN_LENGTH = 73 # results in 13 unfiltered columns (perfect)
|
||||
# COLUMN_LENGTH = (
|
||||
# 106 # results in 9 unfiltered columns (last one short), square-ish image
|
||||
# )
|
||||
OMITTED_COLORS = [
|
||||
# "black",
|
||||
# "white",
|
||||
# "poop",
|
||||
# "poo brown",
|
||||
# "shit",
|
||||
# "shit brown",
|
||||
]
|
||||
# OMITTED_COLORS = list(map(lambda s: f"xkcd:{s}", OMITTED_COLORS))
|
||||
# KIND = "hsv" # choose from umap, hsv
|
||||
for KIND in KINDS:
|
||||
colors = list(mcolors.XKCD_COLORS.keys())
|
||||
sorted_indices = np.load(f"scripts/{KIND}_sorted_indices.npy")
|
||||
sorted_colors = [colors[idx] for idx in sorted_indices]
|
||||
|
||||
colors = sorted_colors
|
||||
colors = [c for c in colors if c not in OMITTED_COLORS]
|
||||
print(f"Total number of colors: {len(colors)}")
|
||||
|
||||
chunks = [
|
||||
colors[i : i + COLUMN_LENGTH] for i in range(0, len(colors), COLUMN_LENGTH)
|
||||
]
|
||||
|
||||
for idx, color_part in enumerate(chunks):
|
||||
image_path = create_color_calibration_image(
|
||||
colors=color_part,
|
||||
ppi=PPIS,
|
||||
index=idx,
|
||||
antialiased=ANTIALIASED,
|
||||
solid_capstyle=CAPSTYLE,
|
||||
)
|
||||
os.system(f"identify {image_path[0]}")
|
||||
|
||||
for PPI in PPIS:
|
||||
# use imagemagick to stitch together the images horizontally
|
||||
os.system(
|
||||
f"convert +append /tmp/color_calibration-{PPI}_*.png /tmp/color_calibration-{PPI}.png"
|
||||
)
|
||||
os.system(f"rm /tmp/color_calibration-{PPI}_*")
|
||||
print(f"Final image saved to /tmp/color_calibration-{PPI}.png")
|
||||
os.system(
|
||||
f"mkdir -p {DIR} && cp /tmp/color_calibration-{PPI}.png {DIR}/xkcd_{COLUMN_LENGTH}_{KIND}_{PPI}.png"
|
||||
)
|
||||
print(f"Copied to {DIR}/xkcd_{COLUMN_LENGTH}_{KIND}_{PPI}.png")
|
@ -4,39 +4,42 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
from utils import extract_colors, preprocess_data
|
||||
|
||||
|
||||
def create_dataloader(N: int = 50, **kwargs):
|
||||
rgb_tensor, _ = extract_colors()
|
||||
rgb_tensor = preprocess_data(rgb_tensor)
|
||||
def create_random_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
|
||||
rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
|
||||
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
|
||||
# Creating a dataset and data loader
|
||||
dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
|
||||
train_dataloader = DataLoader(dataset, **kwargs)
|
||||
return train_dataloader
|
||||
|
||||
|
||||
def create_gray_supplement(N: int = 50):
|
||||
def create_gray_supplement(N: int = 50, skip: bool = True):
|
||||
linear_space = torch.linspace(0, 1, N)
|
||||
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
|
||||
gray_tensor = preprocess_data(gray_tensor)
|
||||
gray_tensor = preprocess_data(gray_tensor, skip=skip)
|
||||
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
|
||||
|
||||
|
||||
def create_named_dataloader(N: int = 0, **kwargs):
|
||||
def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
|
||||
rgb_tensor, xkcd_color_names = extract_colors()
|
||||
rgb_tensor = preprocess_data(rgb_tensor)
|
||||
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
|
||||
# Creating a dataset with RGB values and their corresponding color names
|
||||
dataset_with_names = [
|
||||
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
|
||||
for i in range(len(rgb_tensor))
|
||||
]
|
||||
if N > 0:
|
||||
dataset_with_names += create_gray_supplement(N)
|
||||
dataset_with_names += create_gray_supplement(N, skip=skip)
|
||||
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
|
||||
return train_dataloader_with_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_size = 4
|
||||
train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
|
||||
train_dataloader = create_random_dataloader(
|
||||
N=1e6, batch_size=batch_size, shuffle=True
|
||||
)
|
||||
print(len(train_dataloader.dataset))
|
||||
train_dataloader_with_names = create_named_dataloader(
|
||||
batch_size=batch_size, shuffle=True
|
||||
)
|
||||
|
103
datamodule.py
Normal file
103
datamodule.py
Normal file
@ -0,0 +1,103 @@
|
||||
import lightning as L
|
||||
import torch
|
||||
from matplotlib.colors import rgb_to_hsv
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils import extract_colors, preprocess_data
|
||||
|
||||
|
||||
class ColorDataModule(L.LightningDataModule):
|
||||
def __init__(
|
||||
self,
|
||||
val_size: int = 10_000,
|
||||
train_size=0,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.val_size = val_size
|
||||
self.train_size = train_size
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
||||
def prepare_data(self):
|
||||
# no state. called from main process.
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_hue(cls, v: torch.Tensor) -> torch.Tensor:
|
||||
return torch.tensor([rgb_to_hsv(v)[0]], dtype=torch.float32)
|
||||
|
||||
@classmethod
|
||||
def get_random_colors(cls, size: int):
|
||||
train_rgb = torch.rand((int(size), 3), dtype=torch.float32)
|
||||
train_rgb = preprocess_data(train_rgb, skip=True)
|
||||
return [(c, cls.get_hue(c)) for c in train_rgb]
|
||||
|
||||
@classmethod
|
||||
def get_xkcd_colors(cls):
|
||||
rgb_tensor, xkcd_color_names = extract_colors()
|
||||
rgb_tensor = preprocess_data(rgb_tensor, skip=True)
|
||||
# return [
|
||||
# (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
|
||||
# for i in range(len(rgb_tensor))
|
||||
# ]
|
||||
return [(c, cls.get_hue(c)) for c in rgb_tensor]
|
||||
|
||||
def setup(self, stage: str):
|
||||
# Assign train/val datasets for use in dataloaders
|
||||
if stage == "fit":
|
||||
self.color_val = self.get_random_colors(self.val_size)
|
||||
if self.train_size > 0:
|
||||
self.color_train = self.get_random_colors(self.train_size)
|
||||
else:
|
||||
self.color_train = self.get_xkcd_colors()
|
||||
|
||||
# Assign test dataset for use in dataloader(s)
|
||||
if stage == "test":
|
||||
self.color_test = self.get_random_colors(self.val_size)
|
||||
|
||||
if stage == "predict": # for visualizing
|
||||
self.color_predict = self.get_xkcd_colors()
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.color_train,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(
|
||||
self.color_val,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(
|
||||
self.color_test,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(
|
||||
self.color_predict,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
def teardown(self, stage: str):
|
||||
# Used to clean-up when the run is finished
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cdm = ColorDataModule()
|
||||
cdm.setup("train")
|
||||
print(cdm)
|
@ -1,25 +0,0 @@
|
||||
,batch_size,alpha,learning_rate
|
||||
0,32.0,0.3,0.0001
|
||||
1,32.0,0.3,0.01
|
||||
2,32.0,0.9,1e-06
|
||||
3,32.0,0.7,0.001
|
||||
4,64.0,0.5,0.001
|
||||
5,64.0,0.1,1e-06
|
||||
6,32.0,0.1,0.001
|
||||
7,128.0,0.5,1e-06
|
||||
8,128.0,0.7,0.001
|
||||
9,128.0,0.9,1e-05
|
||||
10,128.0,0.1,1e-06
|
||||
11,128.0,0.3,1e-06
|
||||
12,64.0,0.3,0.01
|
||||
13,64.0,0.1,1e-06
|
||||
14,128.0,0.5,0.001
|
||||
15,32.0,0.3,1e-05
|
||||
16,32.0,0.7,1e-06
|
||||
17,32.0,0.3,1e-06
|
||||
18,64.0,0.3,0.0001
|
||||
19,64.0,0.3,1e-06
|
||||
20,128.0,0.5,1e-05
|
||||
21,32.0,0.1,0.01
|
||||
22,64.0,0.1,1e-05
|
||||
23,64.0,0.3,0.001
|
|
BIN
hsv.png
BIN
hsv.png
Binary file not shown.
Before Width: | Height: | Size: 328 KiB After Width: | Height: | Size: 2.7 MiB |
8
hsv.py
8
hsv.py
@ -7,8 +7,12 @@ from utils import extract_colors
|
||||
if __name__ == "__main__":
|
||||
rgb_tensor, _ = extract_colors()
|
||||
xkcd_rgb = rgb_tensor.numpy()
|
||||
# xkcd_rgb = np.random.rand(1000, 3)
|
||||
xkcd_hsv = rgb_to_hsv(xkcd_rgb)
|
||||
plot_preds(xkcd_hsv[:, 0], fname="hsv")
|
||||
rgb = np.eye(3)
|
||||
plot_preds(
|
||||
xkcd_hsv[:, 0], xkcd_rgb, fname="hsv", roll=True, dpi=300, figsize=(6, 6)
|
||||
)
|
||||
rgb = np.vstack([np.eye(3), np.eye(3) + np.eye(3)[:, [1, 2, 0]]])
|
||||
print("Pure RGB in Hue-Space:")
|
||||
print(rgb)
|
||||
print(rgb_to_hsv(rgb)[:, 0])
|
||||
|
86
hsv1.txt
Normal file
86
hsv1.txt
Normal file
@ -0,0 +1,86 @@
|
||||
# lightning.pytorch==2.1.3
|
||||
seed_everything: 1387
|
||||
trainer:
|
||||
accelerator: auto
|
||||
strategy: auto
|
||||
devices: auto
|
||||
num_nodes: 1
|
||||
precision: null
|
||||
logger: null
|
||||
callbacks:
|
||||
- class_path: callbacks.SaveImageCallback
|
||||
init_args:
|
||||
save_interval: 0
|
||||
final_dir: out
|
||||
fast_dev_run: false
|
||||
max_epochs: 10
|
||||
min_epochs: 10
|
||||
max_steps: -1
|
||||
min_steps: null
|
||||
max_time: null
|
||||
limit_train_batches: null
|
||||
limit_val_batches: 50
|
||||
limit_test_batches: null
|
||||
limit_predict_batches: null
|
||||
overfit_batches: 0.0
|
||||
val_check_interval: null
|
||||
check_val_every_n_epoch: 1
|
||||
num_sanity_val_steps: null
|
||||
log_every_n_steps: 3
|
||||
enable_checkpointing: null
|
||||
enable_progress_bar: null
|
||||
enable_model_summary: null
|
||||
accumulate_grad_batches: 1
|
||||
gradient_clip_val: null
|
||||
gradient_clip_algorithm: null
|
||||
deterministic: null
|
||||
benchmark: null
|
||||
inference_mode: true
|
||||
use_distributed_sampler: true
|
||||
profiler: null
|
||||
detect_anomaly: false
|
||||
barebones: false
|
||||
plugins: null
|
||||
sync_batchnorm: false
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
default_root_dir: null
|
||||
model:
|
||||
transform: tanh
|
||||
width: 128
|
||||
depth: 4
|
||||
bias: true
|
||||
alpha: 0.0
|
||||
data:
|
||||
val_size: 10000
|
||||
train_size: 10000
|
||||
batch_size: 256
|
||||
num_workers: 3
|
||||
ckpt_path: null
|
||||
optimizer:
|
||||
class_path: torch.optim.Adam
|
||||
init_args:
|
||||
lr: 0.001
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
weight_decay: 0.0
|
||||
amsgrad: false
|
||||
foreach: null
|
||||
maximize: false
|
||||
capturable: false
|
||||
differentiable: false
|
||||
fused: null
|
||||
lr_scheduler:
|
||||
class_path: lightning.pytorch.cli.ReduceLROnPlateau
|
||||
init_args:
|
||||
monitor: hp_metric
|
||||
mode: min
|
||||
factor: 0.05
|
||||
patience: 5
|
||||
threshold: 0.0001
|
||||
threshold_mode: rel
|
||||
cooldown: 10
|
||||
min_lr: 0.0
|
||||
eps: 1.0e-08
|
||||
verbose: true
|
86
hsv2.txt
Normal file
86
hsv2.txt
Normal file
@ -0,0 +1,86 @@
|
||||
# lightning.pytorch==2.1.3
|
||||
seed_everything: 31
|
||||
trainer:
|
||||
accelerator: auto
|
||||
strategy: auto
|
||||
devices: auto
|
||||
num_nodes: 1
|
||||
precision: null
|
||||
logger: null
|
||||
callbacks:
|
||||
- class_path: callbacks.SaveImageCallback
|
||||
init_args:
|
||||
save_interval: 0
|
||||
final_dir: out
|
||||
fast_dev_run: false
|
||||
max_epochs: 10
|
||||
min_epochs: 10
|
||||
max_steps: -1
|
||||
min_steps: null
|
||||
max_time: null
|
||||
limit_train_batches: null
|
||||
limit_val_batches: 50
|
||||
limit_test_batches: null
|
||||
limit_predict_batches: null
|
||||
overfit_batches: 0.0
|
||||
val_check_interval: null
|
||||
check_val_every_n_epoch: 1
|
||||
num_sanity_val_steps: null
|
||||
log_every_n_steps: 3
|
||||
enable_checkpointing: null
|
||||
enable_progress_bar: null
|
||||
enable_model_summary: null
|
||||
accumulate_grad_batches: 1
|
||||
gradient_clip_val: null
|
||||
gradient_clip_algorithm: null
|
||||
deterministic: null
|
||||
benchmark: null
|
||||
inference_mode: true
|
||||
use_distributed_sampler: true
|
||||
profiler: null
|
||||
detect_anomaly: false
|
||||
barebones: false
|
||||
plugins: null
|
||||
sync_batchnorm: false
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
default_root_dir: null
|
||||
model:
|
||||
transform: tanh
|
||||
width: 256
|
||||
depth: 8
|
||||
bias: true
|
||||
alpha: 0.0
|
||||
data:
|
||||
val_size: 10000
|
||||
train_size: 10000
|
||||
batch_size: 256
|
||||
num_workers: 3
|
||||
ckpt_path: null
|
||||
optimizer:
|
||||
class_path: torch.optim.AdamW
|
||||
init_args:
|
||||
lr: 0.001
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
weight_decay: 0.01
|
||||
amsgrad: false
|
||||
maximize: false
|
||||
foreach: null
|
||||
capturable: false
|
||||
differentiable: false
|
||||
fused: null
|
||||
lr_scheduler:
|
||||
class_path: lightning.pytorch.cli.ReduceLROnPlateau
|
||||
init_args:
|
||||
monitor: hp_metric
|
||||
mode: min
|
||||
factor: 0.05
|
||||
patience: 5
|
||||
threshold: 0.0001
|
||||
threshold_mode: rel
|
||||
cooldown: 10
|
||||
min_lr: 0.0
|
||||
eps: 1.0e-08
|
||||
verbose: true
|
86
hsv3.txt
Normal file
86
hsv3.txt
Normal file
@ -0,0 +1,86 @@
|
||||
# lightning.pytorch==2.1.3
|
||||
seed_everything: 1009
|
||||
trainer:
|
||||
accelerator: auto
|
||||
strategy: auto
|
||||
devices: auto
|
||||
num_nodes: 1
|
||||
precision: null
|
||||
logger: null
|
||||
callbacks:
|
||||
- class_path: callbacks.SaveImageCallback
|
||||
init_args:
|
||||
save_interval: 0
|
||||
final_dir: out
|
||||
fast_dev_run: false
|
||||
max_epochs: 10
|
||||
min_epochs: 10
|
||||
max_steps: -1
|
||||
min_steps: null
|
||||
max_time: null
|
||||
limit_train_batches: null
|
||||
limit_val_batches: 50
|
||||
limit_test_batches: null
|
||||
limit_predict_batches: null
|
||||
overfit_batches: 0.0
|
||||
val_check_interval: null
|
||||
check_val_every_n_epoch: 1
|
||||
num_sanity_val_steps: null
|
||||
log_every_n_steps: 3
|
||||
enable_checkpointing: null
|
||||
enable_progress_bar: null
|
||||
enable_model_summary: null
|
||||
accumulate_grad_batches: 1
|
||||
gradient_clip_val: null
|
||||
gradient_clip_algorithm: null
|
||||
deterministic: null
|
||||
benchmark: null
|
||||
inference_mode: true
|
||||
use_distributed_sampler: true
|
||||
profiler: null
|
||||
detect_anomaly: false
|
||||
barebones: false
|
||||
plugins: null
|
||||
sync_batchnorm: false
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
default_root_dir: null
|
||||
model:
|
||||
transform: tanh
|
||||
width: 512
|
||||
depth: 2
|
||||
bias: true
|
||||
alpha: 0.75
|
||||
data:
|
||||
val_size: 10000
|
||||
train_size: 10000
|
||||
batch_size: 256
|
||||
num_workers: 3
|
||||
ckpt_path: null
|
||||
optimizer:
|
||||
class_path: torch.optim.AdamW
|
||||
init_args:
|
||||
lr: 0.001
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.999
|
||||
eps: 1.0e-08
|
||||
weight_decay: 0.01
|
||||
amsgrad: false
|
||||
maximize: false
|
||||
foreach: null
|
||||
capturable: false
|
||||
differentiable: false
|
||||
fused: null
|
||||
lr_scheduler:
|
||||
class_path: lightning.pytorch.cli.ReduceLROnPlateau
|
||||
init_args:
|
||||
monitor: hp_metric
|
||||
mode: min
|
||||
factor: 0.05
|
||||
patience: 5
|
||||
threshold: 0.0001
|
||||
threshold_mode: rel
|
||||
cooldown: 10
|
||||
min_lr: 0.0
|
||||
eps: 1.0e-08
|
||||
verbose: true
|
41
index.html
Normal file
41
index.html
Normal file
@ -0,0 +1,41 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Discover the Rainbow</title>
|
||||
<style>
|
||||
body {
|
||||
text-align: center;
|
||||
font-family: Arial, sans-serif;
|
||||
}
|
||||
h1 {
|
||||
margin-top: 50px;
|
||||
}
|
||||
.links {
|
||||
margin-top: 20px;
|
||||
}
|
||||
.links a {
|
||||
display: block;
|
||||
margin: 10px 0;
|
||||
font-size: 18px;
|
||||
color: blue;
|
||||
text-decoration: none;
|
||||
}
|
||||
.links a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Discover the Rainbow</h1>
|
||||
<div class="links">
|
||||
<a href="./out">Current</a>
|
||||
<a href="./out1">Iteration 1</a>
|
||||
<a href="./out2">Iteration 2</a>
|
||||
<a href="./out3">Iteration 3</a>
|
||||
<a href="./out4">Iteration 4 (good refactor, searching for mix of supervision)</a>
|
||||
<a href="./out">Iteration 5 (all supervised)</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
64
losses.py
64
losses.py
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from utils import PURE_RGB
|
||||
from utils import RGBMYC_ANCHOR
|
||||
|
||||
# def smoothness_loss(outputs):
|
||||
# # Sort outputs for smoothness calculation
|
||||
@ -17,55 +17,61 @@ from utils import PURE_RGB
|
||||
# return smoothness_loss
|
||||
|
||||
|
||||
def preservation_loss(inputs, outputs):
|
||||
# Distance Preservation Component
|
||||
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
|
||||
# Distance Preservation Component (or scaled euclidean if given targets)
|
||||
# Encourages the model to keep relative distances from the RGB space in the transformed space
|
||||
if target_inputs is None:
|
||||
target_inputs = inputs
|
||||
else:
|
||||
assert target_outputs is not None
|
||||
if target_outputs is None:
|
||||
target_outputs = outputs
|
||||
|
||||
# Calculate RGB Norm
|
||||
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
|
||||
# max_rgb_distance = 1
|
||||
rgb_norm = (
|
||||
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
|
||||
torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
|
||||
/ max_rgb_distance
|
||||
)
|
||||
rgb_norm = (
|
||||
rgb_norm % 1
|
||||
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
|
||||
# connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
|
||||
# rgb_norm = rgb_norm % 1 # i think this is why yellow and blue end up adjacent.
|
||||
# yes it connects black and white, but also complimentary colors to primary
|
||||
# print(rgb_norm)
|
||||
|
||||
# Calculate 1D Space Norm (modulo 1 to account for circularity)
|
||||
transformed_norm = torch.triu(
|
||||
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
|
||||
)
|
||||
transformed_norm = circle_norm(outputs, target_outputs) * 2
|
||||
|
||||
diff = torch.abs(rgb_norm - transformed_norm)
|
||||
# print(diff)
|
||||
diff = torch.pow(rgb_norm - transformed_norm, 2)
|
||||
# N = len(outputs)
|
||||
# N = (N * (N - 1)) / 2
|
||||
N = torch.count_nonzero(rgb_norm)
|
||||
return torch.sum(diff) / N
|
||||
|
||||
return torch.mean(diff)
|
||||
|
||||
def circle_norm(vector, other_vector):
|
||||
# Assumes vectors are of shape (N,1)
|
||||
diff = torch.abs(vector - other_vector.T)
|
||||
loss_a = torch.triu(diff)
|
||||
loss_b = torch.triu(torch.abs(1 - diff))
|
||||
loss = torch.minimum(loss_a, loss_b)
|
||||
return loss
|
||||
|
||||
|
||||
def separation_loss(red, green, blue):
|
||||
# Separation Component
|
||||
# TODO: remove
|
||||
# Encourages the model to keep R, G, B values equally separated in the transformed space
|
||||
red, green, blue = red % 1, green % 1, blue % 1
|
||||
red_green_distance = torch.min(
|
||||
torch.abs((red - green)), torch.abs((1 + red - green))
|
||||
)
|
||||
red_blue_distance = torch.min(torch.abs((red - blue)), torch.abs((1 + red - blue)))
|
||||
green_blue_distance = torch.min(
|
||||
torch.abs((green - blue)), torch.abs((1 + green - blue))
|
||||
)
|
||||
# print(red_green_distance, red_blue_distance, green_blue_distance)
|
||||
# we want these distances to be equal to one another
|
||||
return (
|
||||
torch.abs(red_green_distance - red_blue_distance)
|
||||
+ torch.abs(red_green_distance - green_blue_distance)
|
||||
+ torch.abs(red_blue_distance - green_blue_distance)
|
||||
)
|
||||
red_loss = torch.abs(0 - red)
|
||||
green_loss = torch.abs(1 / 3 - green) / (2 / 3)
|
||||
blue_loss = torch.abs(2 / 3 - blue) / (2 / 3)
|
||||
return red_loss + green_loss + blue_loss
|
||||
|
||||
|
||||
def calculate_separation_loss(model):
|
||||
# TODO: remove
|
||||
# Wrapper function to calculate separation loss
|
||||
outputs = model(PURE_RGB)
|
||||
outputs = model(RGBMYC_ANCHOR.to(model.device))
|
||||
red, green, blue = outputs[0], outputs[1], outputs[2]
|
||||
return separation_loss(red, green, blue)
|
||||
|
||||
|
30
main.py
30
main.py
@ -4,10 +4,10 @@ import random
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
|
||||
|
||||
from callbacks import SaveImageCallback
|
||||
from dataloader import create_named_dataloader
|
||||
from dataloader import create_named_dataloader as create_dataloader
|
||||
from model import ColorTransformerModel
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ def parse_args():
|
||||
default=3,
|
||||
help="Number of workers for data loading",
|
||||
)
|
||||
parser.add_argument("--width", type=int, default=128, help="Max width of network.")
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
@ -61,13 +62,13 @@ if __name__ == "__main__":
|
||||
|
||||
seed_everything(args.seed)
|
||||
|
||||
early_stop_callback = EarlyStopping(
|
||||
monitor="hp_metric", # Metric to monitor
|
||||
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
|
||||
patience=24, # Number of epochs with no improvement after which training will be stopped
|
||||
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
|
||||
verbose=True,
|
||||
)
|
||||
# early_stop_callback = EarlyStopping(
|
||||
# monitor="hp_metric", # Metric to monitor
|
||||
# min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
|
||||
# patience=5, # Number of epochs with no improvement after which training will be stopped
|
||||
# mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
|
||||
# verbose=True,
|
||||
# )
|
||||
|
||||
save_img_callback = SaveImageCallback(
|
||||
save_interval=0,
|
||||
@ -76,8 +77,9 @@ if __name__ == "__main__":
|
||||
|
||||
# Initialize data loader with parsed arguments
|
||||
# named_data_loader also has grayscale extras. TODO: remove unnamed
|
||||
train_dataloader = create_named_dataloader(
|
||||
N=0,
|
||||
train_dataloader = create_dataloader(
|
||||
# N=1e5,
|
||||
skip=True,
|
||||
batch_size=args.bs,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
@ -87,6 +89,10 @@ if __name__ == "__main__":
|
||||
alpha=args.alpha,
|
||||
learning_rate=args.lr,
|
||||
batch_size=args.bs,
|
||||
width=args.width,
|
||||
bias=False,
|
||||
transform="relu",
|
||||
depth=1,
|
||||
)
|
||||
|
||||
# Initialize model with parsed arguments
|
||||
@ -95,7 +101,7 @@ if __name__ == "__main__":
|
||||
# Initialize trainer with parsed arguments
|
||||
trainer = pl.Trainer(
|
||||
deterministic=True,
|
||||
callbacks=[early_stop_callback, save_img_callback],
|
||||
callbacks=[save_img_callback],
|
||||
max_epochs=args.max_epochs,
|
||||
log_every_n_steps=args.log_every_n_steps,
|
||||
)
|
||||
|
69
makefile
69
makefile
@ -1,20 +1,75 @@
|
||||
lint:
|
||||
black .
|
||||
isort --profile=black .
|
||||
flake8 --ignore E501,W503 .
|
||||
isort --profile=black *.py
|
||||
flake8 --ignore E501,W503,E203 *.py
|
||||
|
||||
test:
|
||||
python main.py --alpha 2 --lr 2e-4 --max_epochs 200
|
||||
# python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
|
||||
python newmain.py fit \
|
||||
--seed_everything 21 \
|
||||
--data.batch_size 256 \
|
||||
--data.train_size 0 \
|
||||
--data.val_size 100000 \
|
||||
--model.alpha 0 \
|
||||
--model.width 2048 \
|
||||
--trainer.fast_dev_run 1 \
|
||||
--trainer.min_epochs 1 \
|
||||
--trainer.max_epochs 10 \
|
||||
--trainer.log_every_n_steps 5 \
|
||||
--trainer.check_val_every_n_epoch 1 \
|
||||
--trainer.callbacks callbacks.SaveImageCallback \
|
||||
--trainer.callbacks.init_args.final_dir out \
|
||||
--trainer.callbacks.init_args.save_interval 0 \
|
||||
--optimizer torch.optim.Adam \
|
||||
--optimizer.init_args.lr 0.01 \
|
||||
--lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
|
||||
--lr_scheduler.init_args.patience 5 \
|
||||
--lr_scheduler.init_args.cooldown 10 \
|
||||
--lr_scheduler.init_args.factor 0.05 \
|
||||
--lr_scheduler.init_args.monitor hp_metric \
|
||||
--lr_scheduler.init_args.verbose true \
|
||||
--print_config
|
||||
|
||||
search:
|
||||
python search.py
|
||||
|
||||
help:
|
||||
# python newmain.py fit --help --trainer.callbacks.help
|
||||
# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
|
||||
python newmain.py fit --help
|
||||
|
||||
search: lint
|
||||
python newsearch.py
|
||||
|
||||
hsv:
|
||||
python hsv.py
|
||||
|
||||
# TODO: replace this with what we used in day-in-the-life
|
||||
animate:
|
||||
ffmpeg -i lightning_logs/version_258/e%04d.png \
|
||||
-c:v libx264 \
|
||||
-vf "fps=12,format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" \
|
||||
~/animated.mp4
|
||||
|
||||
|
||||
umap:
|
||||
for seed in `seq 0 100`; do \
|
||||
python scripts/sortcolor.py -s umap --dpi 300 --seed $$seed ; \
|
||||
done
|
||||
|
||||
sort_umap:
|
||||
python scripts/sortcolor.py -s umap --dpi 300 --seed 21
|
||||
|
||||
parallel_umap:
|
||||
parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 1000)
|
||||
|
||||
sort_lex:
|
||||
python scripts/sortcolor.py -s lex --dpi 300
|
||||
|
||||
sort_hsv:
|
||||
python scripts/sortcolor.py -s hsv --dpi 300
|
||||
|
||||
clean:
|
||||
rm -rf lightning_logs/*
|
||||
rm out/*.png
|
||||
rm -rf lightning_logs
|
||||
rm -rf .lr_find_*.ckpt
|
||||
rm -f out/*.png out/*.txt
|
||||
rm -rf __pycache__/
|
||||
cp hsv.png out/
|
||||
|
173
model.py
173
model.py
@ -1,106 +1,57 @@
|
||||
import pytorch_lightning as pl
|
||||
import lightning as L
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
from losses import calculate_separation_loss, preservation_loss
|
||||
|
||||
# class ColorTransformerModel(pl.LightningModule):
|
||||
# def __init__(self, params):
|
||||
# super().__init__()
|
||||
# self.save_hyperparameters(params)
|
||||
|
||||
# # Model layers
|
||||
# self.layers = nn.Sequential(
|
||||
# nn.Linear(5, 128, bias=False),
|
||||
# nn.Linear(128, 3, bias=False),
|
||||
# nn.ReLU(),
|
||||
# nn.Linear(3, 64, bias=False),
|
||||
# nn.Linear(64, 128, bias=False),
|
||||
# nn.Linear(128, 256, bias=False),
|
||||
# nn.Linear(256, 128, bias=False),
|
||||
# nn.ReLU(),
|
||||
# nn.Linear(128, 1, bias=False),
|
||||
# )
|
||||
|
||||
# def forward(self, x):
|
||||
# x = self.layers(x)
|
||||
# x = (torch.sin(x) + 1) / 2
|
||||
# return x
|
||||
|
||||
# class ColorTransformerModel(pl.LightningModule):
|
||||
# def __init__(self, params):
|
||||
# super().__init__()
|
||||
# self.save_hyperparameters(params)
|
||||
|
||||
# # Embedding layer to expand the input dimensions
|
||||
# self.embedding = nn.Linear(3, 128, bias=False)
|
||||
|
||||
# # Transformer encoder-decoder
|
||||
# encoder = nn.TransformerEncoderLayer(
|
||||
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
|
||||
# )
|
||||
# self.transformer_encoder = nn.TransformerEncoder(
|
||||
# encoder, num_layers=3
|
||||
# )
|
||||
# # lower dimensionality decoder
|
||||
# decoder = nn.TransformerDecoderLayer(
|
||||
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
|
||||
# )
|
||||
# self.transformer_decoder = nn.TransformerDecoder(
|
||||
# decoder, num_layers=3
|
||||
# )
|
||||
|
||||
# # Final linear layer to map back to 1D space
|
||||
# self.final_layer = nn.Linear(128, 1, bias=False)
|
||||
|
||||
# def forward(self, x):
|
||||
# # Embedding the input
|
||||
# x = self.embedding(x)
|
||||
|
||||
# # Adjusting the shape for the transformer
|
||||
# x = x.unsqueeze(1) # Adding a fake sequence dimension
|
||||
|
||||
# # Passing through the transformer
|
||||
# x = self.transformer_encoder(x)
|
||||
|
||||
# # Passing through the decoder
|
||||
# x = self.transformer_decoder(x, memory=x)
|
||||
|
||||
# # Reshape back to original shape
|
||||
# x = x.squeeze(1)
|
||||
|
||||
# # Final linear layer
|
||||
# x = self.final_layer(x)
|
||||
|
||||
# # Apply sigmoid activation to ensure output is in (0, 1)
|
||||
# # x = torch.sigmoid(x)
|
||||
# x = (torch.sin(x) + 1) / 2
|
||||
# return x
|
||||
from losses import circle_norm, preservation_loss # noqa: F401
|
||||
from utils import RGBMYC_ANCHOR
|
||||
|
||||
|
||||
class ColorTransformerModel(pl.LightningModule):
|
||||
def __init__(self, params):
|
||||
class ColorTransformerModel(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
transform: str = "relu",
|
||||
width: int = 128,
|
||||
depth: int = 1,
|
||||
bias: bool = False,
|
||||
alpha: float = 0,
|
||||
lr: float = 0.01,
|
||||
loop: bool = False,
|
||||
dropout=0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters(params)
|
||||
self.save_hyperparameters()
|
||||
|
||||
if self.hparams.transform.lower() == "tanh":
|
||||
t = nn.Tanh
|
||||
elif self.hparams.transform.lower() == "relu":
|
||||
t = nn.ReLU
|
||||
|
||||
w = self.hparams.width
|
||||
d = self.hparams.depth
|
||||
bias = self.hparams.bias
|
||||
if self.hparams.loop:
|
||||
midlayers = []
|
||||
midlayers += [nn.Linear(w, w, bias=bias), t()] * d
|
||||
else:
|
||||
midlayers = sum(
|
||||
[
|
||||
[nn.Linear(w, w, bias=bias), nn.Dropout(self.dropout), t()]
|
||||
for _ in range(d)
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
# Neural network layers
|
||||
self.network = nn.Sequential(
|
||||
nn.Linear(3, 16),
|
||||
nn.ReLU(),
|
||||
nn.Linear(16, 16),
|
||||
nn.ReLU(),
|
||||
nn.Linear(16, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 1),
|
||||
nn.Linear(3, w, bias=bias),
|
||||
t(),
|
||||
*midlayers,
|
||||
nn.Linear(w, 3, bias=bias),
|
||||
t(),
|
||||
nn.Linear(3, 1, bias=bias),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Pass the input through the network
|
||||
x = self.network(x)
|
||||
# Circular mapping
|
||||
# x = (torch.sin(x) + 1) / 2
|
||||
@ -110,25 +61,51 @@ class ColorTransformerModel(pl.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
inputs, labels = batch # x are the RGB inputs, labels are the strings
|
||||
outputs = self.forward(inputs)
|
||||
s_loss = calculate_separation_loss(model=self)
|
||||
rgb_tensor = RGBMYC_ANCHOR.to(self.device) # noqa: F841
|
||||
p_loss = preservation_loss(
|
||||
inputs,
|
||||
outputs,
|
||||
# target_inputs=rgb_tensor,
|
||||
# target_outputs=self.forward(rgb_tensor),
|
||||
)
|
||||
alpha = self.hparams.alpha
|
||||
loss = p_loss + alpha * s_loss
|
||||
self.log("hp_metric", loss)
|
||||
self.log("p_loss", p_loss)
|
||||
self.log("s_loss", s_loss)
|
||||
|
||||
# N = len(outputs)
|
||||
# distance = circle_norm(outputs, labels).mean()
|
||||
distance = torch.norm(outputs - labels).mean()
|
||||
|
||||
# Backprop with this:
|
||||
loss = (1 - alpha) * p_loss + alpha * distance
|
||||
# p_loss is unsupervised (preserve relative distances - either in-batch or to-target)
|
||||
# distance is supervised.
|
||||
self.log("hp_metric", distance)
|
||||
|
||||
# Log all losses individually
|
||||
self.log("train_pres", p_loss)
|
||||
self.log("train_mse", distance)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch):
|
||||
inputs, labels = batch # these are true HSV labels - no learning allowed.
|
||||
outputs = self.forward(inputs)
|
||||
# distance = torch.minimum(
|
||||
# torch.norm(outputs - labels), torch.norm(1 + outputs - labels)
|
||||
# )
|
||||
distance = torch.norm(outputs - labels)
|
||||
mean_loss = torch.mean(distance)
|
||||
max_loss = torch.max(distance)
|
||||
self.log("val_mse", mean_loss)
|
||||
self.log("val_max", max_loss)
|
||||
return mean_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(
|
||||
self.parameters(),
|
||||
lr=self.hparams.learning_rate,
|
||||
lr=self.hparams.lr,
|
||||
)
|
||||
lr_scheduler = ReduceLROnPlateau(
|
||||
optimizer, mode="min", factor=0.05, patience=10, cooldown=20, verbose=True
|
||||
optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
|
15
newmain.py
Normal file
15
newmain.py
Normal file
@ -0,0 +1,15 @@
|
||||
from lightning.pytorch.cli import LightningCLI
|
||||
|
||||
# from callbacks import SaveImageCallback
|
||||
from datamodule import ColorDataModule
|
||||
from model import ColorTransformerModel
|
||||
|
||||
|
||||
def cli_main():
|
||||
cli = LightningCLI(ColorTransformerModel, ColorDataModule) # noqa: F841
|
||||
# note: don't call fit!!
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
||||
# note: it is good practice to implement the CLI in a function and call it in the main if block
|
113
newsearch.py
Normal file
113
newsearch.py
Normal file
@ -0,0 +1,113 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from random import sample, seed
|
||||
|
||||
import numpy as np # noqa: F401
|
||||
from lightning_sdk import Machine, Studio # noqa: F401
|
||||
|
||||
# consistency of randomly sampled experiments.
|
||||
seed(19920921)
|
||||
|
||||
NUM_JOBS = 100
|
||||
|
||||
# reference to the current studio
|
||||
# if you run outside of Lightning, you can pass the Studio name
|
||||
# studio = Studio()
|
||||
|
||||
# use the jobs plugin
|
||||
# studio.install_plugin("jobs")
|
||||
# job_plugin = studio.installed_plugins["jobs"]
|
||||
|
||||
# do a sweep over learning rates
|
||||
|
||||
# Define the ranges or sets of values for each hyperparameter
|
||||
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
|
||||
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
|
||||
learning_rate_values = [1e-3]
|
||||
# learning_rate_values = [5e-4]
|
||||
|
||||
# alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised.
|
||||
alpha_values = [0]
|
||||
# widths = [2**k for k in range(4, 13)]
|
||||
# depths = [1, 2, 4, 8, 16]
|
||||
widths, depths = [512], [4]
|
||||
|
||||
batch_size_values = [256]
|
||||
max_epochs_values = [100]
|
||||
seeds = list(range(21, 1992))
|
||||
optimizers = [
|
||||
# "Adagrad",
|
||||
"Adam",
|
||||
# "SGD",
|
||||
# "AdamW",
|
||||
# "LBFGS",
|
||||
# "RAdam",
|
||||
# "RMSprop",
|
||||
# "Adadelta",
|
||||
]
|
||||
|
||||
# Generate all possible combinations of hyperparameters
|
||||
all_params = [
|
||||
(alpha, lr, bs, me, s, w, d, opt)
|
||||
for alpha in alpha_values
|
||||
for lr in learning_rate_values
|
||||
for bs in batch_size_values
|
||||
for me in max_epochs_values
|
||||
for s in seeds
|
||||
for w in widths
|
||||
for d in depths
|
||||
for opt in optimizers
|
||||
]
|
||||
|
||||
|
||||
# perform random search with a limit
|
||||
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))
|
||||
|
||||
# --trainer.callbacks+ lightning.pytorch.callbacks.EarlyStopping \
|
||||
# --trainer.callbacks.init_args.monitor hp_metric \
|
||||
|
||||
for idx, params in enumerate(search_params):
|
||||
a, lr, bs, me, s, w, d, opt = params
|
||||
# cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
|
||||
cmd = f"""
|
||||
python newmain.py fit \
|
||||
--seed_everything {s} \
|
||||
--data.batch_size {bs} \
|
||||
--data.train_size 0 \
|
||||
--data.val_size 10000 \
|
||||
--model.alpha {a} \
|
||||
--model.width {w} \
|
||||
--model.depth {d} \
|
||||
--model.bias true \
|
||||
--model.loop true \
|
||||
--model.transform tanh \
|
||||
--trainer.min_epochs 10 \
|
||||
--trainer.max_epochs {me} \
|
||||
--trainer.log_every_n_steps 3 \
|
||||
--trainer.check_val_every_n_epoch 1 \
|
||||
--trainer.limit_val_batches 50 \
|
||||
--trainer.callbacks callbacks.SaveImageCallback \
|
||||
--trainer.callbacks.init_args.final_dir out \
|
||||
--trainer.callbacks.init_args.save_interval 0 \
|
||||
--optimizer torch.optim.{opt} \
|
||||
--optimizer.init_args.lr {lr} \
|
||||
--trainer.callbacks+ lightning.pytorch.callbacks.LearningRateFinder
|
||||
# --lr_scheduler lightning.pytorch.cli.ReduceLROnPlateau \
|
||||
# --lr_scheduler.init_args.monitor hp_metric \
|
||||
# --lr_scheduler.init_args.factor 0.05 \
|
||||
# --lr_scheduler.init_args.patience 5 \
|
||||
# --lr_scheduler.init_args.cooldown 10 \
|
||||
# --lr_scheduler.init_args.verbose true
|
||||
"""
|
||||
# job_name = f"color2_{bs}_{a}_{lr:2.2e}"
|
||||
# job_plugin.run(cmd, machine=Machine.T4, name=job_name)
|
||||
print(f"Running {params}: {cmd}")
|
||||
try:
|
||||
# Run the command and wait for it to complete
|
||||
# subprocess.run(test_cmd, shell=True, check=True)
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted by user")
|
||||
sys.exit(1)
|
||||
# except subprocess.CalledProcessError:
|
||||
# pass
|
@ -72,8 +72,13 @@
|
||||
function loadImages() {
|
||||
var gallery = document.getElementById('gallery');
|
||||
|
||||
for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100
|
||||
let imageName = 'v' + i + '.png';
|
||||
for (var i = 0; i < 200; i++) { // Changed from i <= 100 to i < 100
|
||||
let imageName;
|
||||
if (i == -21) {
|
||||
imageName = 'hsv.png';
|
||||
} else {
|
||||
imageName = 'v' + i + '.png';
|
||||
}
|
||||
let img = document.createElement('img');
|
||||
img.src = imageName;
|
||||
img.onerror = function () { this.style.display = 'none'; };
|
||||
|
210
requirements.txt
Normal file
210
requirements.txt
Normal file
@ -0,0 +1,210 @@
|
||||
absl-py==2.0.0
|
||||
aiobotocore==2.11.0
|
||||
aiofiles==22.1.0
|
||||
aiohttp==3.9.1
|
||||
aioitertools==0.11.0
|
||||
aiosignal==1.3.1
|
||||
aiosqlite==0.19.0
|
||||
annotated-types==0.6.0
|
||||
antlr4-python3-runtime==4.9.3
|
||||
anyio==4.2.0
|
||||
argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==21.2.0
|
||||
arrow==1.3.0
|
||||
asttokens==2.4.1
|
||||
async-timeout==4.0.3
|
||||
attrs==23.2.0
|
||||
Babel==2.14.0
|
||||
backoff==2.2.1
|
||||
beautifulsoup4==4.12.2
|
||||
bitsandbytes==0.42.0
|
||||
black==23.12.1
|
||||
bleach==6.1.0
|
||||
blessed==1.20.0
|
||||
boto3==1.34.17
|
||||
botocore==1.34.17
|
||||
cachetools==5.3.2
|
||||
certifi==2023.11.17
|
||||
cffi==1.16.0
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
comm==0.2.1
|
||||
contourpy==1.2.0
|
||||
croniter==1.4.1
|
||||
cycler==0.12.1
|
||||
dateutils==0.6.12
|
||||
debugpy==1.8.0
|
||||
decorator==5.1.1
|
||||
deepdiff==6.7.1
|
||||
defusedxml==0.7.1
|
||||
docker==6.1.3
|
||||
docstring-parser==0.15
|
||||
editor==1.6.5
|
||||
entrypoints==0.4
|
||||
exceptiongroup==1.2.0
|
||||
executing==2.0.1
|
||||
fastapi==0.109.0
|
||||
fastjsonschema==2.19.1
|
||||
filelock==3.13.1
|
||||
flake8==7.0.0
|
||||
fonttools==4.47.2
|
||||
fqdn==1.5.1
|
||||
frozenlist==1.4.1
|
||||
fsspec==2023.12.2
|
||||
google-auth==2.26.2
|
||||
google-auth-oauthlib==1.2.0
|
||||
grpcio==1.60.0
|
||||
h11==0.14.0
|
||||
hydra-core==1.3.2
|
||||
idna==3.6
|
||||
importlib-metadata==7.0.1
|
||||
importlib-resources==6.1.1
|
||||
iniconfig==2.0.0
|
||||
inquirer==3.2.1
|
||||
ipykernel==6.26.0
|
||||
ipython==8.17.2
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==8.1.1
|
||||
isoduration==20.11.0
|
||||
isort==5.13.2
|
||||
jedi==0.19.1
|
||||
Jinja2==3.1.3
|
||||
jmespath==1.0.1
|
||||
joblib==1.3.2
|
||||
json5==0.9.14
|
||||
jsonargparse==4.27.2
|
||||
jsonpointer==2.4
|
||||
jsonschema==4.20.0
|
||||
jsonschema-specifications==2023.12.1
|
||||
jupyter-events==0.9.0
|
||||
jupyter-ydoc==0.2.5
|
||||
jupyter_client==7.4.9
|
||||
jupyter_core==5.7.1
|
||||
jupyter_server==2.12.4
|
||||
jupyter_server_fileid==0.9.1
|
||||
jupyter_server_terminals==0.5.1
|
||||
jupyter_server_ydoc==0.6.1
|
||||
jupyterlab==3.6.1
|
||||
jupyterlab-widgets==3.0.9
|
||||
jupyterlab_pygments==0.3.0
|
||||
jupyterlab_server==2.25.2
|
||||
kiwisolver==1.4.5
|
||||
lightning==2.1.3
|
||||
lightning-api-access==0.0.5
|
||||
lightning-cloud==0.5.57
|
||||
lightning-fabric==2.1.3
|
||||
lightning-utilities==0.10.0
|
||||
lightning_sdk==0.0.13a0
|
||||
Markdown==3.5.2
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==2.1.3
|
||||
matplotlib==3.8.2
|
||||
matplotlib-inline==0.1.6
|
||||
mccabe==0.7.0
|
||||
mdurl==0.1.2
|
||||
mistune==3.0.2
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
mypy-extensions==1.0.0
|
||||
nbclassic==1.0.0
|
||||
nbclient==0.9.0
|
||||
nbconvert==7.14.1
|
||||
nbformat==5.9.2
|
||||
nest-asyncio==1.5.8
|
||||
networkx==3.2.1
|
||||
notebook==6.5.6
|
||||
notebook_shim==0.2.3
|
||||
numpy==1.26.2
|
||||
oauthlib==3.2.2
|
||||
omegaconf==2.3.0
|
||||
ordered-set==4.1.0
|
||||
overrides==7.4.0
|
||||
packaging==23.2
|
||||
pandas==2.1.4
|
||||
pandocfilters==1.5.0
|
||||
parso==0.8.3
|
||||
pathspec==0.12.1
|
||||
pexpect==4.9.0
|
||||
pillow==10.2.0
|
||||
platformdirs==4.1.0
|
||||
pluggy==1.4.0
|
||||
prometheus-client==0.19.0
|
||||
prompt-toolkit==3.0.43
|
||||
protobuf==4.23.4
|
||||
psutil==5.9.7
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
pyasn1==0.5.1
|
||||
pyasn1-modules==0.3.0
|
||||
pycodestyle==2.11.1
|
||||
pycparser==2.21
|
||||
pydantic==2.5.3
|
||||
pydantic_core==2.14.6
|
||||
pyflakes==3.2.0
|
||||
Pygments==2.17.2
|
||||
PyJWT==2.8.0
|
||||
pyparsing==3.1.1
|
||||
pytest==7.4.4
|
||||
python-dateutil==2.8.2
|
||||
python-json-logger==2.0.7
|
||||
python-multipart==0.0.6
|
||||
pytorch-lightning==2.1.2
|
||||
pytz==2023.3.post1
|
||||
PyYAML==6.0.1
|
||||
pyzmq==24.0.1
|
||||
readchar==4.0.5
|
||||
redis==5.0.1
|
||||
referencing==0.32.1
|
||||
requests==2.31.0
|
||||
requests-oauthlib==1.3.1
|
||||
rfc3339-validator==0.1.4
|
||||
rfc3986-validator==0.1.1
|
||||
rich==13.7.0
|
||||
rpds-py==0.17.1
|
||||
rsa==4.9
|
||||
runs==1.2.0
|
||||
s3fs==2023.12.2
|
||||
s3transfer==0.10.0
|
||||
scikit-learn==1.3.2
|
||||
scipy==1.11.4
|
||||
Send2Trash==1.8.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.0
|
||||
soupsieve==2.5
|
||||
stack-data==0.6.3
|
||||
starlette==0.35.1
|
||||
sympy==1.12
|
||||
tensorboard==2.15.1
|
||||
tensorboard-data-server==0.7.2
|
||||
tensorboardX==2.6.2.2
|
||||
terminado==0.18.0
|
||||
threadpoolctl==3.2.0
|
||||
tinycss2==1.2.1
|
||||
tomli==2.0.1
|
||||
torch==2.1.1+cu118
|
||||
torchmetrics==1.3.0.post0
|
||||
tornado==6.4
|
||||
tqdm==4.66.1
|
||||
traitlets==5.14.1
|
||||
triton==2.1.0
|
||||
types-python-dateutil==2.8.19.20240106
|
||||
typeshed-client==2.4.0
|
||||
typing_extensions==4.9.0
|
||||
tzdata==2023.4
|
||||
uri-template==1.3.0
|
||||
urllib3==2.0.7
|
||||
uvicorn==0.25.0
|
||||
watermark==2.4.3
|
||||
wcwidth==0.2.13
|
||||
webcolors==1.13
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.7.0
|
||||
websockets==11.0.3
|
||||
Werkzeug==3.0.1
|
||||
widgetsnbextension==4.0.9
|
||||
wrapt==1.16.0
|
||||
xmod==1.8.1
|
||||
y-py==0.6.2
|
||||
yarl==1.9.4
|
||||
ypy-websocket==0.8.4
|
||||
zipp==3.17.0
|
64
scripts/place.sh
Normal file
64
scripts/place.sh
Normal file
@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Constants
|
||||
# SEED=20230920
|
||||
SEED=0
|
||||
# INPUT_FILE="${DIR}/arrangement_$SEED.txt"
|
||||
DIR=/teamspace/studios/this_studio/out_sortcolors
|
||||
INPUT_FILE="${DIR}/arrangement_grid.txt"
|
||||
# OUTPUT_IMAGE="${DIR}/circle_composite_$SEED.png"
|
||||
TYPE=circle
|
||||
OUTPUT_IMAGE="${DIR}/${TYPE}_composite_grid.png"
|
||||
DPI=150
|
||||
CANVAS_SIZE=$((72*$DPI)) # 72 inches
|
||||
CIRCLE_IMAGE="${DIR}/hsv/sorted_colors_circle.png"
|
||||
KIND=umap
|
||||
|
||||
identify $CIRCLE_IMAGE
|
||||
|
||||
# PREP
|
||||
echo "Building ops"
|
||||
# Build the composite operations string
|
||||
composite_ops=""
|
||||
idx=1
|
||||
while IFS=, read -r x y; do
|
||||
# Translate so that (0,0) becomes the center of the canvas
|
||||
fx=$(echo "$x*$DPI + $CANVAS_SIZE/2" | bc -l)
|
||||
fy=$(echo "$CANVAS_SIZE/2 - $y*$DPI" | bc -l)
|
||||
|
||||
# Convert the final float values to integers
|
||||
ix=$(printf "%.0f" "$fx")
|
||||
iy=$(printf "%.0f" "$fy")
|
||||
if [[ idx -eq 42 ]]; then
|
||||
CIRCLE_IMAGE="${DIR}/hsv/sorted_colors_${TYPE}.png"
|
||||
else
|
||||
idx_str=$(printf "%04d" "$idx")
|
||||
CIRCLE_IMAGE="${DIR}/${KIND}/${idx_str}_sorted_colors_${TYPE}.png"
|
||||
# CIRCLE_IMAGE="${DIR}/hsv_sorted_colors_${TYPE}.png"
|
||||
fi
|
||||
|
||||
# Add to the composite operations string
|
||||
composite_ops="$composite_ops \( $CIRCLE_IMAGE \) -resize 50% -compose Over -geometry +$ix+$iy -composite"
|
||||
idx=$((idx+1))
|
||||
|
||||
done < $INPUT_FILE
|
||||
|
||||
# COMPOSITE
|
||||
echo "Compositing"
|
||||
# Use convert with the built composite operations string
|
||||
eval "convert -units PixelsPerInch \
|
||||
-size ${CANVAS_SIZE}x${CANVAS_SIZE} xc:white \
|
||||
$composite_ops \
|
||||
-density $DPI \
|
||||
$OUTPUT_IMAGE"
|
||||
|
||||
echo "Saved $OUTPUT_IMAGE"
|
||||
|
||||
# DEBUG
|
||||
# eval "convert -units PixelsPerInch \
|
||||
# $OUTPUT_IMAGE \
|
||||
# \( ${DIR}/arrangement_$SEED.png -evaluate Multiply 0.5 \) \
|
||||
# -gravity center -composite \
|
||||
# -density $DPI \
|
||||
# /tmp/debug.png"
|
||||
# open /tmp/debug.png
|
5
scripts/requirements.cuml.sh
Normal file
5
scripts/requirements.cuml.sh
Normal file
@ -0,0 +1,5 @@
|
||||
pip install \
|
||||
--extra-index-url=https://pypi.nvidia.com \
|
||||
cudf-cu12==23.12.* cuml-cu12==23.12.* \
|
||||
dask-cudf-cu12==23.12.* pylibraft-cu12==23.12.* \
|
||||
raft-dask-cu12==23.12.*
|
6
scripts/requirements.rapids.sh
Normal file
6
scripts/requirements.rapids.sh
Normal file
@ -0,0 +1,6 @@
|
||||
pip install \
|
||||
--extra-index-url=https://pypi.nvidia.com \
|
||||
cudf-cu12==23.12.* dask-cudf-cu12==23.12.* cuml-cu12==23.12.* \
|
||||
cugraph-cu12==23.12.* cuspatial-cu12==23.12.* cuproj-cu12==23.12.* \
|
||||
cuxfilter-cu12==23.12.* cucim-cu12==23.12.* pylibraft-cu12==23.12.* \
|
||||
raft-dask-cu12==23.12.*
|
148
scripts/scatter.py
Normal file
148
scripts/scatter.py
Normal file
@ -0,0 +1,148 @@
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_quadrant(x, y):
|
||||
"""Return the quadrant of a given point."""
|
||||
if x >= 0 and y >= 0:
|
||||
return 1
|
||||
elif x < 0 and y >= 0:
|
||||
return 2
|
||||
elif x < 0 and y < 0:
|
||||
return 3
|
||||
else:
|
||||
return 4
|
||||
|
||||
|
||||
def is_overlapping(x, y, existing_points, radius):
|
||||
"""Check if a point overlaps with existing points, crosses a quadrant or goes beyond the quadrant lines."""
|
||||
current_quadrant = get_quadrant(x, y)
|
||||
buffer = 0.25
|
||||
# Check if circle touches/crosses the x=0 or y=0 lines.
|
||||
if (
|
||||
abs(x) + radius > buffer
|
||||
and abs(x) - radius < buffer
|
||||
or abs(y) + radius > buffer
|
||||
and abs(y) - radius < buffer
|
||||
):
|
||||
return True
|
||||
|
||||
for ex, ey in existing_points:
|
||||
# Check overlap with existing points
|
||||
if np.sqrt((x - ex) ** 2 + (y - ey) ** 2) <= 2 * radius + 0.5:
|
||||
return True
|
||||
# Check if the circle touches another quadrant
|
||||
if (
|
||||
get_quadrant(ex, ey) != current_quadrant
|
||||
and np.sqrt((x - ex) ** 2 + (y - ey) ** 2) <= radius
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
DIR = "/teamspace/studios/this_studio/out_sortcolors"
|
||||
|
||||
N = 100
|
||||
DPI = 300
|
||||
SIZE = 72 # canvas size?
|
||||
radius = 3
|
||||
variance = 72 # Adjust variance as needed
|
||||
|
||||
Path(DIR).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
with open(f"{DIR}/arrangement_grid.txt", "w") as f:
|
||||
radius = 3
|
||||
half = SIZE / 2 - 4
|
||||
# make points an equispaced grid of 10 x 10 ranging from [-35, 35]
|
||||
interval = (half - (-half)) / 9 # 10 points, so 9 intervals
|
||||
|
||||
# Generate the grid points
|
||||
points = [
|
||||
(x, y)
|
||||
for x in np.arange(-half, half + 0.1, interval)
|
||||
for y in np.arange(-half, half + 0.1, interval)
|
||||
]
|
||||
|
||||
for x, y in points:
|
||||
f.write(f"{x-radius}, {y+radius}\n")
|
||||
# f.write(f"{x}, {y}\n")
|
||||
|
||||
print("wrote grid")
|
||||
|
||||
|
||||
for seed in range(11, 22):
|
||||
try:
|
||||
np.random.seed(seed)
|
||||
|
||||
# To store plotted points
|
||||
points = []
|
||||
max_iterations = int(1e7)
|
||||
iterations = 0
|
||||
# Generate points
|
||||
for k in range(N):
|
||||
while True:
|
||||
# x, y = np.random.normal(0, variance), np.random.normal(0, variance)
|
||||
random_angle = np.random.uniform(0, 2 * np.pi)
|
||||
# random_radius = np.random.uniform(0.25+radius, SIZE/2 - radius - 0.25)
|
||||
random_radius = abs(np.random.normal(0, variance))
|
||||
x = random_radius * np.cos(random_angle)
|
||||
y = random_radius * np.sin(random_angle)
|
||||
iterations += 1
|
||||
if not is_overlapping(x, y, points, radius):
|
||||
if max(abs(x), abs(y)) + radius < SIZE / 2 - 0.25:
|
||||
points.append((x, y))
|
||||
break
|
||||
if iterations > max_iterations:
|
||||
raise ValueError(f"Too many iterations: {k} points")
|
||||
print(f"{k}: ({x}, {y}) @ {iterations:09d}")
|
||||
|
||||
# Create plot with circles
|
||||
fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
|
||||
for x, y in points:
|
||||
circle = patches.Circle((x, y), radius, color="black")
|
||||
ax.add_patch(circle)
|
||||
|
||||
# Draw the standard quadrants
|
||||
ax.axhline(0, color="black", linewidth=0.5)
|
||||
ax.axvline(0, color="black", linewidth=0.5)
|
||||
|
||||
# Use square axis and set limits
|
||||
|
||||
lim_x = (
|
||||
max(
|
||||
abs(max(points, key=lambda t: t[0])[0]),
|
||||
abs(min(points, key=lambda t: t[0])[0]),
|
||||
)
|
||||
+ radius
|
||||
)
|
||||
lim_y = (
|
||||
max(
|
||||
abs(max(points, key=lambda t: t[1])[1]),
|
||||
abs(min(points, key=lambda t: t[1])[1]),
|
||||
)
|
||||
+ radius
|
||||
)
|
||||
lim_x = lim_y = SIZE / 2
|
||||
ax.axis("off")
|
||||
ax.set_aspect("equal")
|
||||
ax.set_xlim(-lim_x, lim_x)
|
||||
ax.set_ylim(-lim_y, lim_y)
|
||||
|
||||
# Save and show
|
||||
fig.tight_layout(pad=0)
|
||||
plt.savefig(
|
||||
f"{DIR}/arrangement_{seed}.png", dpi=DPI, bbox_inches="tight", pad_inches=0
|
||||
)
|
||||
# plt.show()
|
||||
# also save x/y coords as text file
|
||||
with open(f"{DIR}/arrangement_{seed}.txt", "w") as f:
|
||||
for x, y in points:
|
||||
f.write(f"{x-radius}, {y+radius}\n")
|
||||
# f.write(f"{x}, {y}\n")
|
||||
except ValueError as e:
|
||||
print(f"{seed}: {e}")
|
||||
except AssertionError as e:
|
||||
print(f"{seed}: {e}")
|
379
scripts/sortcolor.py
Normal file
379
scripts/sortcolor.py
Normal file
@ -0,0 +1,379 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from hilbertcurve.hilbertcurve import HilbertCurve
|
||||
|
||||
# Extract XKCD colors
|
||||
colors = list(mcolors.XKCD_COLORS.keys())
|
||||
rgb_values = [mcolors.to_rgb(mcolors.XKCD_COLORS[color]) for color in colors]
|
||||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--sort-by", type=str, default="hsv", help="kind of sorting")
|
||||
parser.add_argument("--seed", type=int, default=21, help="seed for UMAP")
|
||||
parser.add_argument("--dpi", type=int, default=100, help="dpi for saving")
|
||||
parser.add_argument("--size", type=float, default=6.0, help="size of figure")
|
||||
parser.add_argument(
|
||||
"--fontsize",
|
||||
type=float,
|
||||
default=0,
|
||||
help="fontsize of annotation (default: 0 = None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--radius", type=float, default=1 / 3, help="inner radius of circle"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
KIND = args.sort_by
|
||||
SEED = args.seed
|
||||
DPI = args.dpi
|
||||
SIZE = args.size
|
||||
FONTSIZE = args.fontsize
|
||||
INNER_RADIUS = args.radius
|
||||
DIR = "/teamspace/studios/this_studio/out_sortcolors"
|
||||
|
||||
|
||||
prefix = ""
|
||||
if KIND == "umap":
|
||||
prefix = f"{SEED:04d}_"
|
||||
FDIR = f"{DIR}/{KIND}"
|
||||
Path(FDIR).mkdir(exist_ok=True, parents=True)
|
||||
fname = f"{FDIR}/{prefix}sorted_colors_circle.png"
|
||||
|
||||
|
||||
def peano_curve(n):
|
||||
"""
|
||||
Generate Peano curve coordinates for a given order `n`.
|
||||
"""
|
||||
if n == 0:
|
||||
return [(0, 0)]
|
||||
|
||||
prev_curve = peano_curve(n - 1)
|
||||
max_coord = 3 ** (n - 1)
|
||||
|
||||
# Define the transformations for the Peano curve's 9 segments
|
||||
transforms = [
|
||||
lambda x, y: (x, y), # Bottom-left square
|
||||
lambda x, y: (x + max_coord, y), # Bottom-middle square
|
||||
lambda x, y: (x + 2 * max_coord, y), # Bottom-right square
|
||||
lambda x, y: (x + 2 * max_coord, y + max_coord), # Middle-right square
|
||||
lambda x, y: (
|
||||
x + max_coord,
|
||||
y + max_coord,
|
||||
), # Center square (traversed in reverse)
|
||||
lambda x, y: (x, y + max_coord), # Middle-left square
|
||||
lambda x, y: (x, y + 2 * max_coord), # Top-left square
|
||||
lambda x, y: (x + max_coord, y + 2 * max_coord), # Top-middle square
|
||||
lambda x, y: (x + 2 * max_coord, y + 2 * max_coord), # Top-right square
|
||||
]
|
||||
|
||||
curve = []
|
||||
for transform in transforms:
|
||||
segment = [transform(x, y) for x, y in prev_curve]
|
||||
# Reverse the traversal for the center square
|
||||
if transform == transforms[4]:
|
||||
segment = segment[::-1]
|
||||
curve += segment
|
||||
|
||||
return curve
|
||||
|
||||
|
||||
if KIND in ("lex", "alpha", "abc"):
|
||||
preds = np.array(colors)
|
||||
|
||||
elif KIND == "umap":
|
||||
# from umap import UMAP
|
||||
from cuml import UMAP
|
||||
|
||||
# Use UMAP to create a 1D representation
|
||||
reducer = UMAP(
|
||||
n_components=1,
|
||||
n_neighbors=250,
|
||||
min_dist=1e-2,
|
||||
metric="euclidean",
|
||||
random_state=SEED,
|
||||
negative_sample_rate=2,
|
||||
)
|
||||
embedding = reducer.fit_transform(np.array(rgb_values))
|
||||
|
||||
# Sort colors by the 1D representation
|
||||
preds = embedding[:, 0]
|
||||
del reducer, embedding
|
||||
|
||||
elif KIND in ("cielab", "lab", "ciede2000"):
|
||||
from skimage.color import deltaE_ciede2000, rgb2lab
|
||||
|
||||
# CIELAB
|
||||
# Convert RGB values to CIELAB
|
||||
lab_values = rgb2lab([rgb_values])
|
||||
|
||||
# Reference color for sorting (can be the first color or any other reference point)
|
||||
reference_color = lab_values[0]
|
||||
|
||||
# Compute CIEDE2000 distances of all colors to the reference color
|
||||
distances = [deltaE_ciede2000(reference_color, color) for color in lab_values]
|
||||
|
||||
# Sort colors by their CIEDE2000 distance to the reference color
|
||||
# preds = distances).flatten() # awful
|
||||
lab_values_flat = lab_values.reshape(-1, 3)
|
||||
# Sort colors based on the L* value in the CIELAB space
|
||||
# 0 corresponds to the L* channel
|
||||
preds = lab_values_flat[:, 0]
|
||||
|
||||
elif KIND == "hsv":
|
||||
from matplotlib.colors import rgb_to_hsv
|
||||
|
||||
# Convert RGB values to HSV
|
||||
hsv_values = np.array([rgb_to_hsv(np.array(rgb)) for rgb in rgb_values])
|
||||
|
||||
# Sort colors based on the hue value
|
||||
# 0 corresponds to the hue component
|
||||
preds = hsv_values[:, 0]
|
||||
else:
|
||||
raise ValueError(f"Unknown kind: {KIND}")
|
||||
|
||||
sorted_indices = np.argsort(preds)
|
||||
|
||||
# Save the sorted indices to disk
|
||||
# if (KIND == "umap") or (KIND != "umap"):
|
||||
PDIR = f"scripts/{KIND}"
|
||||
Path(PDIR).mkdir(parents=True, exist_ok=True)
|
||||
file_path = f"{PDIR}/{SEED:06d}.npy"
|
||||
np.save(file_path, preds.ravel())
|
||||
print(f"Predictions saved to {file_path}")
|
||||
|
||||
# Sort colors by the 1D representation
|
||||
sorted_colors = [colors[i] for i in sorted_indices]
|
||||
|
||||
# # Display the sorted colors around the ring of a circle
|
||||
# # Create a new figure for the circle visualization
|
||||
# fig, ax = plt.subplots(figsize=(SIZE, SIZE))
|
||||
|
||||
# # Circle parameters
|
||||
# center = (0, 0)
|
||||
# radius = 1.0
|
||||
|
||||
# # Angle increment (in radians) based on the number of colors
|
||||
# angle_increment = 2 * np.pi / len(sorted_colors)
|
||||
# # roll the colors so that the first color is white
|
||||
# reordered_sorted_colors = sorted_colors.copy()
|
||||
# white_index = reordered_sorted_colors.index("xkcd:white")
|
||||
# reordered_sorted_colors = reordered_sorted_colors[white_index:] + reordered_sorted_colors[:white_index]
|
||||
# # Plot each color around the circle
|
||||
# for i, color in enumerate(reordered_sorted_colors):
|
||||
# # Compute start and end angles for the segment
|
||||
# start_angle = i * angle_increment
|
||||
# end_angle = (i + 1) * angle_increment
|
||||
|
||||
# # Create a wedge (segment of the circle) for each color
|
||||
# wedge = patches.Wedge(
|
||||
# center, radius, 90 + np.degrees(start_angle), 90 + np.degrees(end_angle), fc=color
|
||||
# )
|
||||
|
||||
# ax.add_patch(wedge)
|
||||
|
||||
# # Overlay a white circle in the center
|
||||
# inner_circle = patches.Circle(center, INNER_RADIUS, color="white")
|
||||
# ax.add_patch(inner_circle)
|
||||
# if INNER_RADIUS > 0.0:
|
||||
# fcolor = "black"
|
||||
# else:
|
||||
# fcolor = "white"
|
||||
|
||||
# if FONTSIZE > 0.0:
|
||||
# ax.annotate(f"{KIND.upper()}", center, ha="center", va="center", size=FONTSIZE, color=fcolor)
|
||||
|
||||
# # Set equal scaling and remove axis
|
||||
# ax.set_aspect("equal")
|
||||
# ax.axis("off")
|
||||
# ax.set_ylim(-radius, radius)
|
||||
# ax.set_xlim(-radius, radius)
|
||||
|
||||
# # Save and display the circle visualization
|
||||
# prefix = ""
|
||||
# if KIND == "umap":
|
||||
# prefix = f"{SEED:02d}"
|
||||
|
||||
# fname = f"{DIR}/{prefix}{KIND}_sorted_colors_circle.png"
|
||||
# fig.tight_layout(pad=0)
|
||||
# fig.savefig(
|
||||
# fname,
|
||||
# dpi=DPI,
|
||||
# transparent=True,
|
||||
# pad_inches=0,
|
||||
# bbox_inches="tight",
|
||||
# )
|
||||
# print(f"Saved {fname}")
|
||||
# # plt.show()
|
||||
|
||||
|
||||
def plot_preds(
|
||||
preds,
|
||||
rgb_values,
|
||||
fname: str,
|
||||
roll: bool = False,
|
||||
dpi: int = 150,
|
||||
inner_radius: float = 1 / 3,
|
||||
figsize=(3, 3),
|
||||
):
|
||||
sorted_inds = np.argsort(preds.ravel())
|
||||
colors = rgb_values[sorted_inds, :3]
|
||||
if roll:
|
||||
# find white in colors, put it first.
|
||||
white = np.array([1, 1, 1])
|
||||
white_idx = np.where((colors == white).all(axis=1))
|
||||
if white_idx:
|
||||
white_idx = white_idx[0][0]
|
||||
colors = np.roll(colors, -white_idx, axis=0)
|
||||
else:
|
||||
print("no white, skipping")
|
||||
# print(white_idx, colors[:2])
|
||||
|
||||
N = len(colors)
|
||||
# Create a plot with these hues in a circle
|
||||
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))
|
||||
|
||||
# Each wedge in the circle
|
||||
theta = np.linspace(0, 2 * np.pi, N, endpoint=False) + np.pi / 2
|
||||
width = 2 * np.pi / (N) # equal size for each wedge
|
||||
|
||||
for i in range(N):
|
||||
ax.bar(
|
||||
# 2 * np.pi * preds[i],
|
||||
theta[i],
|
||||
height=1,
|
||||
width=width,
|
||||
edgecolor=colors[i],
|
||||
linewidth=0.25,
|
||||
# facecolor=[rgb_values[i][1]]*3,
|
||||
# facecolor=rgb_values[i],
|
||||
facecolor=colors[i],
|
||||
bottom=0.0,
|
||||
zorder=1,
|
||||
alpha=1,
|
||||
align="edge",
|
||||
)
|
||||
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
radius = 1
|
||||
ax.set_ylim(0, radius)
|
||||
|
||||
# Overlay white circle
|
||||
circle = patches.Circle(
|
||||
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
|
||||
if FONTSIZE > 0.0:
|
||||
center = (0, 0)
|
||||
fcolor = "black"
|
||||
ax.annotate(
|
||||
f"{KIND.upper()}",
|
||||
center,
|
||||
ha="center",
|
||||
va="center",
|
||||
size=FONTSIZE,
|
||||
color=fcolor,
|
||||
)
|
||||
|
||||
fig.tight_layout(pad=0)
|
||||
|
||||
plt.savefig(fname, dpi=dpi, transparent=True, pad_inches=0, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
|
||||
plot_preds(
|
||||
preds,
|
||||
np.array(rgb_values),
|
||||
fname,
|
||||
roll=False,
|
||||
dpi=DPI,
|
||||
inner_radius=INNER_RADIUS,
|
||||
figsize=(SIZE, SIZE),
|
||||
)
|
||||
print(f"saved {fname}")
|
||||
|
||||
HILBERT = False
|
||||
|
||||
if HILBERT:
|
||||
# Create Hilbert curve
|
||||
# We'll set the order such that the number of positions is greater than or equal to the number of colors
|
||||
hilbert_order = int(np.ceil(0.5 * np.log2(len(sorted_colors))))
|
||||
hilbert_curve = HilbertCurve(hilbert_order, 2)
|
||||
|
||||
# Create an image for visualization
|
||||
image_size = 2**hilbert_order
|
||||
image = np.ones((image_size, image_size, 3))
|
||||
|
||||
for i, color in enumerate(sorted_colors):
|
||||
# Convert linear index to Hilbert coordinates
|
||||
coords = hilbert_curve.point_from_distance(i)
|
||||
image[coords[1], coords[0]] = mcolors.to_rgb(color)
|
||||
|
||||
# annotation in upper right
|
||||
# Display the image
|
||||
fig, ax = plt.subplots(1, 1, figsize=(SIZE, SIZE))
|
||||
ax.imshow(image)
|
||||
ax.annotate(
|
||||
f"{KIND.upper()}",
|
||||
(1.0, 1.0),
|
||||
ha="right",
|
||||
va="top",
|
||||
size=FONTSIZE,
|
||||
xycoords="axes fraction",
|
||||
)
|
||||
ax.axis("off")
|
||||
ax.set_aspect("equal")
|
||||
fig.tight_layout()
|
||||
fname = f"{DIR}/{prefix}{KIND}_sorted_colors_hilbert.png"
|
||||
fig.savefig(
|
||||
fname,
|
||||
dpi=DPI,
|
||||
transparent=True,
|
||||
# bbox_inches="tight",
|
||||
# pad_inches=0
|
||||
)
|
||||
print(f"Saved {fname}")
|
||||
|
||||
# plt.show()
|
||||
|
||||
|
||||
# # Create peano curve
|
||||
# order = 0
|
||||
# while (3**order)**2 < len(sorted_colors):
|
||||
# order += 1
|
||||
|
||||
# # Generate peano curve coordinates for the determined order
|
||||
# curve_coords = peano_curve(order)
|
||||
# unique_coords = set(curve_coords)
|
||||
# print(f"Total coordinates: {len(curve_coords)}")
|
||||
# print(f"Unique coordinates: {len(unique_coords)}")
|
||||
|
||||
# # If there are more points on the curve than colors, truncate the curve
|
||||
# if len(curve_coords) > len(sorted_colors):
|
||||
# print(f"Will be missing {len(curve_coords) - len(sorted_colors)} pixels, percentage: {100 * (len(curve_coords) - len(sorted_colors)) / len(curve_coords)}")
|
||||
# curve_coords = curve_coords[:len(sorted_colors)]
|
||||
# if len(curve_coords) < len(sorted_colors):
|
||||
# raise ValueError("Not enough curve coordinates for the number of colors")
|
||||
|
||||
# # Create an image for visualization
|
||||
# image_size = (3**order)
|
||||
# image = np.ones((image_size, image_size, 3))
|
||||
|
||||
# for i, color in enumerate(sorted_colors):
|
||||
# # Get the Moore curve coordinates for the current index
|
||||
# coords = curve_coords[i]
|
||||
# image[coords[1], coords[0]] = mcolors.to_rgb(color)
|
||||
|
||||
# # Display the image
|
||||
# plt.figure(figsize=(8, 8))
|
||||
# plt.imshow(image)
|
||||
# plt.axis("off")
|
||||
# plt.savefig(f"{DIR}/{KIND}_sorted_colors_moore.png", dpi=DPI)
|
||||
# plt.show()
|
67
scripts/vips_composite.py
Normal file
67
scripts/vips_composite.py
Normal file
@ -0,0 +1,67 @@
|
||||
import pyvips
|
||||
import glob
|
||||
import random
|
||||
import os
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def bg_fill(width: int, height: int, channels: int = 3, fill: int = 255):
|
||||
a = np.ones(shape=(height, width, channels)) * fill
|
||||
return pyvips.Image.new_from_array(a)
|
||||
|
||||
|
||||
def create_grid_composite(
|
||||
directory, k, spacing_inches, output_file="output.png", dpi=300
|
||||
):
|
||||
# Calculate spacing in pixels
|
||||
spacing_pixels = int(spacing_inches * dpi)
|
||||
|
||||
# Glob for PNG images
|
||||
png_files = glob.glob(os.path.join(directory, "*.png"))
|
||||
|
||||
# Randomly select K^2 images
|
||||
selected_files = random.sample(png_files, k * k)
|
||||
|
||||
# Create an empty list to hold the images
|
||||
images = [
|
||||
pyvips.Image.new_from_file(file, access="sequential") for file in selected_files
|
||||
]
|
||||
|
||||
# Calculate the size of the composite image
|
||||
# widths, heights = zip(*[image.size for image in images])
|
||||
widths, heights = [1800], [1800]
|
||||
max_width = max(widths)
|
||||
max_height = max(heights)
|
||||
|
||||
# Calculate total size of the grid including spacing
|
||||
total_width = k * max_width + (k + 1) * spacing_pixels
|
||||
total_height = k * max_height + (k + 1) * spacing_pixels
|
||||
|
||||
# Create a blank image for the composite
|
||||
# composite = pyvips.Image.black(total_width, total_height, bands=4)
|
||||
composite = bg_fill(total_width, total_height, channels=1, fill=255)
|
||||
|
||||
# Place images into the composite
|
||||
for i, image in enumerate(images):
|
||||
row = i // k
|
||||
col = i % k
|
||||
x = col * (max_width + spacing_pixels) + spacing_pixels
|
||||
y = row * (max_height + spacing_pixels) + spacing_pixels
|
||||
|
||||
composite = composite.insert(image.flatten(background=[255, 255, 255]), x, y)
|
||||
|
||||
# Save the composite image
|
||||
composite.write_to_file(output_file)
|
||||
x = pyvips.Image.thumbnail(output_file, 1080 * 4)
|
||||
x.write_to_file("out_sm.png")
|
||||
print(output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
directory = "/teamspace/studios/this_studio/out_sortcolors/umap/" # Change to your directory path
|
||||
k = 10
|
||||
spacing_inches = 0.5 # Half an inch between images
|
||||
create_grid_composite(directory, k, spacing_inches)
|
28
search.py
28
search.py
@ -2,37 +2,41 @@ import subprocess
|
||||
import sys
|
||||
from random import sample
|
||||
|
||||
import numpy as np
|
||||
import numpy as np # noqa: F401
|
||||
from lightning_sdk import Machine, Studio # noqa: F401
|
||||
|
||||
NUM_JOBS = 100
|
||||
|
||||
# reference to the current studio
|
||||
# if you run outside of Lightning, you can pass the Studio name
|
||||
studio = Studio()
|
||||
# studio = Studio()
|
||||
|
||||
# use the jobs plugin
|
||||
studio.install_plugin("jobs")
|
||||
job_plugin = studio.installed_plugins["jobs"]
|
||||
# studio.install_plugin("jobs")
|
||||
# job_plugin = studio.installed_plugins["jobs"]
|
||||
|
||||
# do a sweep over learning rates
|
||||
|
||||
# Define the ranges or sets of values for each hyperparameter
|
||||
alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
|
||||
# learning_rate_values = list(np.round(np.logspace(-5, -3, 41), 5))
|
||||
learning_rate_values = [5e-4]
|
||||
batch_size_values = [128]
|
||||
max_epochs_values = [500]
|
||||
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
|
||||
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
|
||||
learning_rate_values = [1e-2]
|
||||
alpha_values = [0, 1, 2]
|
||||
widths = [2**k for k in range(4, 15)]
|
||||
# learning_rate_values = [5e-4]
|
||||
batch_size_values = [256]
|
||||
max_epochs_values = [100]
|
||||
seeds = list(range(21, 1992))
|
||||
|
||||
# Generate all possible combinations of hyperparameters
|
||||
all_params = [
|
||||
(alpha, lr, bs, me, s)
|
||||
(alpha, lr, bs, me, s, w)
|
||||
for alpha in alpha_values
|
||||
for lr in learning_rate_values
|
||||
for bs in batch_size_values
|
||||
for me in max_epochs_values
|
||||
for s in seeds
|
||||
for w in widths
|
||||
]
|
||||
|
||||
|
||||
@ -40,8 +44,8 @@ all_params = [
|
||||
search_params = sample(all_params, min(NUM_JOBS, len(all_params)))
|
||||
|
||||
for idx, params in enumerate(search_params):
|
||||
a, lr, bs, me, s = params
|
||||
cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s}"
|
||||
a, lr, bs, me, s, w = params
|
||||
cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me} --seed {s} --width {w}"
|
||||
# job_name = f"color2_{bs}_{a}_{lr:2.2e}"
|
||||
# job_plugin.run(cmd, machine=Machine.T4, name=job_name)
|
||||
print(f"Running {params}: {cmd}")
|
||||
|
9
utils.py
9
utils.py
@ -2,7 +2,7 @@ import matplotlib.colors as mcolors
|
||||
import torch
|
||||
|
||||
|
||||
def preprocess_data(data, skip=True):
|
||||
def preprocess_data(data, skip: bool = True):
|
||||
# Assuming 'data' is a tensor of shape [n_samples, 3]
|
||||
if not skip:
|
||||
# Compute argmin and argmax for each row
|
||||
@ -34,6 +34,9 @@ def extract_colors():
|
||||
return rgb_tensor, xkcd_color_names
|
||||
|
||||
|
||||
PURE_RGB = preprocess_data(
|
||||
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
|
||||
RGBMYC_ANCHOR = preprocess_data(
|
||||
torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
|
||||
)
|
||||
PURE_HSV = torch.tensor(
|
||||
[[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user