Browse Source

another attempt

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
70ecd7d7db
  1. 4
      check.py
  2. 14
      dataloader.py
  3. 18
      losses.py
  4. 25
      main.py
  5. 2
      makefile
  6. 46
      model.py
  7. 8
      search.py
  8. 3
      utils.py

4
check.py

@ -39,7 +39,7 @@ def make_image(ckpt: str, fname: str, color=True):
plt.savefig(f"{fname}.png", dpi=300)
def create_circle(ckpt: str, fname: str, dpi: int = 150):
def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True):
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
else:
@ -57,7 +57,7 @@ def plot_preds(
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :]
colors = rgb_values[sorted_inds, :3]
if roll:
# find white in colors, put it first.
white = np.array([1, 1, 1])

14
dataloader.py

@ -4,32 +4,32 @@ from torch.utils.data import DataLoader, TensorDataset
from utils import extract_colors, preprocess_data
def create_dataloader(N: int = 1e8, **kwargs):
def create_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
rgb_tensor = preprocess_data(rgb_tensor)
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset and data loader
dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
train_dataloader = DataLoader(dataset, **kwargs)
return train_dataloader
def create_gray_supplement(N: int = 50):
def create_gray_supplement(N: int = 50, skip: bool = True):
linear_space = torch.linspace(0, 1, N)
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
gray_tensor = preprocess_data(gray_tensor)
gray_tensor = preprocess_data(gray_tensor, skip=skip)
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
def create_named_dataloader(N: int = 0, **kwargs):
def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset with RGB values and their corresponding color names
dataset_with_names = [
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
for i in range(len(rgb_tensor))
]
if N > 0:
dataset_with_names += create_gray_supplement(N)
dataset_with_names += create_gray_supplement(N, skip=skip)
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
return train_dataloader_with_names

18
losses.py

@ -17,24 +17,30 @@ from utils import PURE_RGB
# return smoothness_loss
def preservation_loss(inputs, outputs):
def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
# Distance Preservation Component
# Encourages the model to keep relative distances from the RGB space in the transformed space
if target_inputs is None:
target_inputs = inputs
else:
assert target_outputs is not None
if target_outputs is None:
target_outputs = outputs
# Calculate RGB Norm
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
# max_rgb_distance = 1
rgb_norm = (
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
/ max_rgb_distance
)
rgb_norm = (
rgb_norm % 1
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
# connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
rgb_norm = rgb_norm % 1
# print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm = torch.triu(
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
)
diff = torch.abs(rgb_norm - transformed_norm)

25
main.py

@ -4,10 +4,10 @@ import random
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
from callbacks import SaveImageCallback
from dataloader import create_dataloader
from dataloader import create_named_dataloader as create_dataloader
from model import ColorTransformerModel
@ -62,23 +62,24 @@ if __name__ == "__main__":
seed_everything(args.seed)
early_stop_callback = EarlyStopping(
monitor="hp_metric", # Metric to monitor
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
patience=5, # Number of epochs with no improvement after which training will be stopped
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
verbose=True,
)
# early_stop_callback = EarlyStopping(
# monitor="hp_metric", # Metric to monitor
# min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
# patience=5, # Number of epochs with no improvement after which training will be stopped
# mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
# verbose=True,
# )
save_img_callback = SaveImageCallback(
save_interval=0,
final_dir=None,
final_dir="out",
)
# Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_dataloader(
N=1e5,
# N=1e5,
skip=False,
batch_size=args.bs,
shuffle=True,
num_workers=args.num_workers,
@ -97,7 +98,7 @@ if __name__ == "__main__":
# Initialize trainer with parsed arguments
trainer = pl.Trainer(
deterministic=True,
callbacks=[early_stop_callback, save_img_callback],
callbacks=[save_img_callback],
max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps,
)

2
makefile

@ -4,7 +4,7 @@ lint:
flake8 --ignore E501,W503 .
test:
python main.py --alpha 2 --lr 1e-3 --max_epochs 200 --bs 16384 --seed 1914
python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
search:
python search.py

46
model.py

@ -3,7 +3,8 @@ import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss
from losses import calculate_separation_loss, preservation_loss # noqa: F401
from utils import PURE_HSV, PURE_RGB
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
@ -83,18 +84,40 @@ class ColorTransformerModel(pl.LightningModule):
def __init__(self, params):
super().__init__()
self.save_hyperparameters(params)
# self.a = nn.Sequential(
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 1, bias=False),
# nn.ReLU(),
# )
# self.b = nn.Sequential(
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 1, bias=False),
# nn.ReLU(),
# )
# Neural network layers
self.network = nn.Sequential(
nn.Linear(3, self.hparams.width),
nn.ReLU(),
nn.Linear(self.hparams.width, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Linear(5, 64),
nn.Tanh(),
nn.Linear(64, self.hparams.width),
nn.Tanh(),
nn.Linear(self.hparams.width, 3),
nn.Tanh(),
nn.Linear(3, 1),
)
def forward(self, x):
# Pass the input through the network
# a = self.a(x)
# b = self.b(x)
# a = torch.sigmoid(a)
# b = torch.sigmoid(b)
# x = torch.cat([x, a, b], dim=-1)
x = self.network(x)
# Circular mapping
# x = (torch.sin(x) + 1) / 2
@ -104,7 +127,14 @@ class ColorTransformerModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs)
s_loss = calculate_separation_loss(model=self)
# s_loss = calculate_separation_loss(model=self)
# preserve distance to pure R, G, B. this acts kind of like labeled data.
s_loss = preservation_loss(
inputs,
outputs,
target_inputs=PURE_RGB,
target_outputs=PURE_HSV,
)
p_loss = preservation_loss(
inputs,
outputs,

8
search.py

@ -20,12 +20,12 @@ NUM_JOBS = 100
# Define the ranges or sets of values for each hyperparameter
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-2, 1e-3]
learning_rate_values = [1e-2]
alpha_values = [0, 1, 2]
widths = [64, 128, 256, 512]
widths = [2**k for k in range(4, 15)]
# learning_rate_values = [5e-4]
batch_size_values = [8192]
max_epochs_values = [50]
batch_size_values = [256]
max_epochs_values = [100]
seeds = list(range(21, 1992))
# Generate all possible combinations of hyperparameters

3
utils.py

@ -2,7 +2,7 @@ import matplotlib.colors as mcolors
import torch
def preprocess_data(data, skip=True):
def preprocess_data(data, skip: bool = False):
# Assuming 'data' is a tensor of shape [n_samples, 3]
if not skip:
# Compute argmin and argmax for each row
@ -37,3 +37,4 @@ def extract_colors():
PURE_RGB = preprocess_data(
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
)
PURE_HSV = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)

Loading…
Cancel
Save