Browse Source

another attempt

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
70ecd7d7db
  1. 4
      check.py
  2. 14
      dataloader.py
  3. 18
      losses.py
  4. 25
      main.py
  5. 2
      makefile
  6. 46
      model.py
  7. 8
      search.py
  8. 3
      utils.py

4
check.py

@ -39,7 +39,7 @@ def make_image(ckpt: str, fname: str, color=True):
plt.savefig(f"{fname}.png", dpi=300) plt.savefig(f"{fname}.png", dpi=300)
def create_circle(ckpt: str, fname: str, dpi: int = 150): def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
else: else:
@ -57,7 +57,7 @@ def plot_preds(
if isinstance(preds, torch.Tensor): if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy() preds = preds.detach().cpu().numpy()
sorted_inds = np.argsort(preds.ravel()) sorted_inds = np.argsort(preds.ravel())
colors = rgb_values[sorted_inds, :] colors = rgb_values[sorted_inds, :3]
if roll: if roll:
# find white in colors, put it first. # find white in colors, put it first.
white = np.array([1, 1, 1]) white = np.array([1, 1, 1])

14
dataloader.py

@ -4,32 +4,32 @@ from torch.utils.data import DataLoader, TensorDataset
from utils import extract_colors, preprocess_data from utils import extract_colors, preprocess_data
def create_dataloader(N: int = 1e8, **kwargs): def create_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32) rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset and data loader # Creating a dataset and data loader
dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor))) dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
train_dataloader = DataLoader(dataset, **kwargs) train_dataloader = DataLoader(dataset, **kwargs)
return train_dataloader return train_dataloader
def create_gray_supplement(N: int = 50): def create_gray_supplement(N: int = 50, skip: bool = True):
linear_space = torch.linspace(0, 1, N) linear_space = torch.linspace(0, 1, N)
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3) gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
gray_tensor = preprocess_data(gray_tensor) gray_tensor = preprocess_data(gray_tensor, skip=skip)
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))] return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
def create_named_dataloader(N: int = 0, **kwargs): def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
rgb_tensor, xkcd_color_names = extract_colors() rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset with RGB values and their corresponding color names # Creating a dataset with RGB values and their corresponding color names
dataset_with_names = [ dataset_with_names = [
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
for i in range(len(rgb_tensor)) for i in range(len(rgb_tensor))
] ]
if N > 0: if N > 0:
dataset_with_names += create_gray_supplement(N) dataset_with_names += create_gray_supplement(N, skip=skip)
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs) train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
return train_dataloader_with_names return train_dataloader_with_names

18
losses.py

@ -17,24 +17,30 @@ from utils import PURE_RGB
# return smoothness_loss # return smoothness_loss
def preservation_loss(inputs, outputs): def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
# Distance Preservation Component # Distance Preservation Component
# Encourages the model to keep relative distances from the RGB space in the transformed space # Encourages the model to keep relative distances from the RGB space in the transformed space
if target_inputs is None:
target_inputs = inputs
else:
assert target_outputs is not None
if target_outputs is None:
target_outputs = outputs
# Calculate RGB Norm # Calculate RGB Norm
max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1] max_rgb_distance = torch.sqrt(torch.tensor(2 + 1)) # scale to [0, 1]
# max_rgb_distance = 1
rgb_norm = ( rgb_norm = (
torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)) torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
/ max_rgb_distance / max_rgb_distance
) )
rgb_norm = ( # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
rgb_norm % 1 rgb_norm = rgb_norm % 1
) # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
# print(rgb_norm) # print(rgb_norm)
# Calculate 1D Space Norm (modulo 1 to account for circularity) # Calculate 1D Space Norm (modulo 1 to account for circularity)
transformed_norm = torch.triu( transformed_norm = torch.triu(
torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1) torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
) )
diff = torch.abs(rgb_norm - transformed_norm) diff = torch.abs(rgb_norm - transformed_norm)

25
main.py

@ -4,10 +4,10 @@ import random
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping # noqa: F401
from callbacks import SaveImageCallback from callbacks import SaveImageCallback
from dataloader import create_dataloader from dataloader import create_named_dataloader as create_dataloader
from model import ColorTransformerModel from model import ColorTransformerModel
@ -62,23 +62,24 @@ if __name__ == "__main__":
seed_everything(args.seed) seed_everything(args.seed)
early_stop_callback = EarlyStopping( # early_stop_callback = EarlyStopping(
monitor="hp_metric", # Metric to monitor # monitor="hp_metric", # Metric to monitor
min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement # min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
patience=5, # Number of epochs with no improvement after which training will be stopped # patience=5, # Number of epochs with no improvement after which training will be stopped
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. # mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
verbose=True, # verbose=True,
) # )
save_img_callback = SaveImageCallback( save_img_callback = SaveImageCallback(
save_interval=0, save_interval=0,
final_dir=None, final_dir="out",
) )
# Initialize data loader with parsed arguments # Initialize data loader with parsed arguments
# named_data_loader also has grayscale extras. TODO: remove unnamed # named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_dataloader( train_dataloader = create_dataloader(
N=1e5, # N=1e5,
skip=False,
batch_size=args.bs, batch_size=args.bs,
shuffle=True, shuffle=True,
num_workers=args.num_workers, num_workers=args.num_workers,
@ -97,7 +98,7 @@ if __name__ == "__main__":
# Initialize trainer with parsed arguments # Initialize trainer with parsed arguments
trainer = pl.Trainer( trainer = pl.Trainer(
deterministic=True, deterministic=True,
callbacks=[early_stop_callback, save_img_callback], callbacks=[save_img_callback],
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps, log_every_n_steps=args.log_every_n_steps,
) )

2
makefile

@ -4,7 +4,7 @@ lint:
flake8 --ignore E501,W503 . flake8 --ignore E501,W503 .
test: test:
python main.py --alpha 2 --lr 1e-3 --max_epochs 200 --bs 16384 --seed 1914 python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
search: search:
python search.py python search.py

46
model.py

@ -3,7 +3,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss from losses import calculate_separation_loss, preservation_loss # noqa: F401
from utils import PURE_HSV, PURE_RGB
# class ColorTransformerModel(pl.LightningModule): # class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params): # def __init__(self, params):
@ -83,18 +84,40 @@ class ColorTransformerModel(pl.LightningModule):
def __init__(self, params): def __init__(self, params):
super().__init__() super().__init__()
self.save_hyperparameters(params) self.save_hyperparameters(params)
# self.a = nn.Sequential(
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 1, bias=False),
# nn.ReLU(),
# )
# self.b = nn.Sequential(
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 1, bias=False),
# nn.ReLU(),
# )
# Neural network layers # Neural network layers
self.network = nn.Sequential( self.network = nn.Sequential(
nn.Linear(3, self.hparams.width), nn.Linear(5, 64),
nn.ReLU(), nn.Tanh(),
nn.Linear(self.hparams.width, 64), nn.Linear(64, self.hparams.width),
nn.ReLU(), nn.Tanh(),
nn.Linear(64, 1), nn.Linear(self.hparams.width, 3),
nn.Tanh(),
nn.Linear(3, 1),
) )
def forward(self, x): def forward(self, x):
# Pass the input through the network # Pass the input through the network
# a = self.a(x)
# b = self.b(x)
# a = torch.sigmoid(a)
# b = torch.sigmoid(b)
# x = torch.cat([x, a, b], dim=-1)
x = self.network(x) x = self.network(x)
# Circular mapping # Circular mapping
# x = (torch.sin(x) + 1) / 2 # x = (torch.sin(x) + 1) / 2
@ -104,7 +127,14 @@ class ColorTransformerModel(pl.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs) outputs = self.forward(inputs)
s_loss = calculate_separation_loss(model=self) # s_loss = calculate_separation_loss(model=self)
# preserve distance to pure R, G, B. this acts kind of like labeled data.
s_loss = preservation_loss(
inputs,
outputs,
target_inputs=PURE_RGB,
target_outputs=PURE_HSV,
)
p_loss = preservation_loss( p_loss = preservation_loss(
inputs, inputs,
outputs, outputs,

8
search.py

@ -20,12 +20,12 @@ NUM_JOBS = 100
# Define the ranges or sets of values for each hyperparameter # Define the ranges or sets of values for each hyperparameter
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4)) # alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5)) # learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-2, 1e-3] learning_rate_values = [1e-2]
alpha_values = [0, 1, 2] alpha_values = [0, 1, 2]
widths = [64, 128, 256, 512] widths = [2**k for k in range(4, 15)]
# learning_rate_values = [5e-4] # learning_rate_values = [5e-4]
batch_size_values = [8192] batch_size_values = [256]
max_epochs_values = [50] max_epochs_values = [100]
seeds = list(range(21, 1992)) seeds = list(range(21, 1992))
# Generate all possible combinations of hyperparameters # Generate all possible combinations of hyperparameters

3
utils.py

@ -2,7 +2,7 @@ import matplotlib.colors as mcolors
import torch import torch
def preprocess_data(data, skip=True): def preprocess_data(data, skip: bool = False):
# Assuming 'data' is a tensor of shape [n_samples, 3] # Assuming 'data' is a tensor of shape [n_samples, 3]
if not skip: if not skip:
# Compute argmin and argmax for each row # Compute argmin and argmax for each row
@ -37,3 +37,4 @@ def extract_colors():
PURE_RGB = preprocess_data( PURE_RGB = preprocess_data(
torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32) torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
) )
PURE_HSV = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)

Loading…
Cancel
Save