From 70ecd7d7dbd4c449ebe6622118b2c509295d9d97 Mon Sep 17 00:00:00 2001
From: Michael Pilosov
Date: Tue, 16 Jan 2024 04:37:22 +0000
Subject: [PATCH] another attempt

---
 check.py      |  4 ++--
 dataloader.py | 14 +++++++-------
 losses.py     | 18 ++++++++++++------
 main.py       | 25 +++++++++++++------------
 makefile      |  2 +-
 model.py      | 46 ++++++++++++++++++++++++++++++++++++++--------
 search.py     |  8 ++++----
 utils.py      |  3 ++-
 8 files changed, 79 insertions(+), 41 deletions(-)

diff --git a/check.py b/check.py
index 693daf2..025bdde 100644
--- a/check.py
+++ b/check.py
@@ -39,7 +39,7 @@ def make_image(ckpt: str, fname: str, color=True):
     plt.savefig(f"{fname}.png", dpi=300)
 
 
-def create_circle(ckpt: str, fname: str, dpi: int = 150):
+def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True):
     if isinstance(ckpt, str):
         M = ColorTransformerModel.load_from_checkpoint(ckpt)
     else:
@@ -57,7 +57,7 @@ def plot_preds(
     if isinstance(preds, torch.Tensor):
         preds = preds.detach().cpu().numpy()
     sorted_inds = np.argsort(preds.ravel())
-    colors = rgb_values[sorted_inds, :]
+    colors = rgb_values[sorted_inds, :3]
     if roll:
         # find white in colors, put it first.
         white = np.array([1, 1, 1])
diff --git a/dataloader.py b/dataloader.py
index bb70154..63120ac 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -4,32 +4,32 @@
 from torch.utils.data import DataLoader, TensorDataset
 from utils import extract_colors, preprocess_data
 
 
-def create_dataloader(N: int = 1e8, **kwargs):
+def create_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
     rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
-    rgb_tensor = preprocess_data(rgb_tensor)
+    rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
     # Creating a dataset and data loader
     dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
     train_dataloader = DataLoader(dataset, **kwargs)
     return train_dataloader
 
 
-def create_gray_supplement(N: int = 50):
+def create_gray_supplement(N: int = 50, skip: bool = True):
     linear_space = torch.linspace(0, 1, N)
     gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
-    gray_tensor = preprocess_data(gray_tensor)
+    gray_tensor = preprocess_data(gray_tensor, skip=skip)
     return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
 
 
-def create_named_dataloader(N: int = 0, **kwargs):
+def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
     rgb_tensor, xkcd_color_names = extract_colors()
-    rgb_tensor = preprocess_data(rgb_tensor)
+    rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
     # Creating a dataset with RGB values and their corresponding color names
     dataset_with_names = [
         (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
         for i in range(len(rgb_tensor))
     ]
     if N > 0:
-        dataset_with_names += create_gray_supplement(N)
+        dataset_with_names += create_gray_supplement(N, skip=skip)
     train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
     return train_dataloader_with_names
diff --git a/losses.py b/losses.py
index 30dfde4..3de252f 100644
--- a/losses.py
+++ b/losses.py
@@ -17,24 +17,30 @@ from utils import PURE_RGB
 #     return smoothness_loss
 
 
-def preservation_loss(inputs, outputs):
+def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # Distance Preservation Component
     # Encourages the model to keep relative distances from the RGB space in the transformed space
+    if target_inputs is None:
+        target_inputs = inputs
+    else:
+        assert target_outputs is not None
+    if target_outputs is None:
+        target_outputs = outputs
 
     # Calculate RGB Norm
     max_rgb_distance = torch.sqrt(torch.tensor(2 + 1))  # scale to [0, 1]
+    # max_rgb_distance = 1
     rgb_norm = (
-        torch.triu(torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1))
+        torch.triu(torch.norm(inputs[:, None, :] - target_inputs[None, :, :], dim=-1))
         / max_rgb_distance
     )
-    rgb_norm = (
-        rgb_norm % 1
-    )  # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
+    # connect (0, 0, 0) and (1, 1, 1): max_rgb_distance in the RGB space
+    rgb_norm = rgb_norm % 1
     # print(rgb_norm)
 
     # Calculate 1D Space Norm (modulo 1 to account for circularity)
     transformed_norm = torch.triu(
-        torch.norm((outputs[:, None] - outputs[None, :]) % 1, dim=-1)
+        torch.norm((outputs[:, None] - target_outputs[None, :]) % 1, dim=-1)
     )
 
     diff = torch.abs(rgb_norm - transformed_norm)
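Note on the losses.py change: preservation_loss now accepts optional anchor
targets. A minimal sketch of how the anchored path is exercised (assumes
preprocess_data with skip=False produces 5-feature rows, as the new
nn.Linear(5, 64) input layer in model.py implies; `outputs` below is a
stand-in for model(inputs)):

    import torch

    from losses import preservation_loss
    from utils import PURE_HSV, PURE_RGB, preprocess_data

    inputs = preprocess_data(torch.rand(4, 3))  # assumed shape [4, 5] with skip=False
    outputs = torch.rand(4, 1)  # stand-in for model(inputs), values in [0, 1)

    # Pairwise variant (previous behavior): targets default to the batch itself.
    loss_pairwise = preservation_loss(inputs, outputs)

    # Anchored variant (new here): compare each sample against the fixed
    # pure-R/G/B inputs and their prescribed positions on the hue circle.
    loss_anchored = preservation_loss(
        inputs, outputs, target_inputs=PURE_RGB, target_outputs=PURE_HSV
    )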
diff --git a/main.py b/main.py
index 1c31d5a..2fe83f9 100644
--- a/main.py
+++ b/main.py
@@ -4,10 +4,10 @@ import random
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping  # noqa: F401
 
 from callbacks import SaveImageCallback
-from dataloader import create_dataloader
+from dataloader import create_named_dataloader as create_dataloader
 from model import ColorTransformerModel
 
 
@@ -62,23 +62,24 @@ if __name__ == "__main__":
 
     seed_everything(args.seed)
 
-    early_stop_callback = EarlyStopping(
-        monitor="hp_metric",  # Metric to monitor
-        min_delta=1e-5,  # Minimum change in the monitored quantity to qualify as an improvement
-        patience=5,  # Number of epochs with no improvement after which training will be stopped
-        mode="min",  # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
-        verbose=True,
-    )
+    # early_stop_callback = EarlyStopping(
+    #     monitor="hp_metric",  # Metric to monitor
+    #     min_delta=1e-5,  # Minimum change in the monitored quantity to qualify as an improvement
+    #     patience=5,  # Number of epochs with no improvement after which training will be stopped
+    #     mode="min",  # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
+    #     verbose=True,
+    # )
 
     save_img_callback = SaveImageCallback(
         save_interval=0,
-        final_dir=None,
+        final_dir="out",
     )
 
     # Initialize data loader with parsed arguments
     # named_data_loader also has grayscale extras. TODO: remove unnamed
     train_dataloader = create_dataloader(
-        N=1e5,
+        # N=1e5,
+        skip=False,
         batch_size=args.bs,
         shuffle=True,
         num_workers=args.num_workers,
@@ -97,7 +98,7 @@ if __name__ == "__main__":
     # Initialize trainer with parsed arguments
     trainer = pl.Trainer(
         deterministic=True,
-        callbacks=[early_stop_callback, save_img_callback],
+        callbacks=[save_img_callback],
         max_epochs=args.max_epochs,
         log_every_n_steps=args.log_every_n_steps,
     )
diff --git a/makefile b/makefile
index a10d80d..e2fd771 100644
--- a/makefile
+++ b/makefile
@@ -4,7 +4,7 @@ lint:
 	flake8 --ignore E501,W503 .
 
 test:
-	python main.py --alpha 2 --lr 1e-3 --max_epochs 200 --bs 16384 --seed 1914
+	python main.py --alpha 1 --lr 1e-2 --max_epochs 200 --bs 256 --seed 856 --width 2048
 
 search:
 	python search.py
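Note on the skip=False flag now threaded through main.py and the dataloaders:
preprocess_data is only partially visible in this patch, but its "Compute
argmin and argmax for each row" comment, the new nn.Linear(5, 64) input layer,
and the rgb_values[sorted_inds, :3] slice in check.py all suggest it appends
two derived columns per RGB row. A hypothetical reconstruction, for
orientation only (the name, scaling, and exact features are guesses, not the
repo's actual code):

    import torch

    def preprocess_data_sketch(data: torch.Tensor, skip: bool = False) -> torch.Tensor:
        # Hypothetical stand-in for utils.preprocess_data: with skip=True the
        # [n, 3] RGB rows pass through unchanged; with skip=False two derived
        # columns (per-row argmin and argmax channel indices) are appended,
        # yielding the [n, 5] rows the new nn.Linear(5, 64) layer expects.
        if skip:
            return data
        argmin = data.argmin(dim=1, keepdim=True).float()
        argmax = data.argmax(dim=1, keepdim=True).float()
        return torch.cat([data, argmin, argmax], dim=1)

    x = torch.rand(8, 3)
    print(preprocess_data_sketch(x, skip=True).shape)   # torch.Size([8, 3])
    print(preprocess_data_sketch(x, skip=False).shape)  # torch.Size([8, 5])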
diff --git a/model.py b/model.py
index 30e3658..e24415f 100644
--- a/model.py
+++ b/model.py
@@ -3,7 +3,8 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import calculate_separation_loss, preservation_loss
+from losses import calculate_separation_loss, preservation_loss  # noqa: F401
+from utils import PURE_HSV, PURE_RGB
 
 # class ColorTransformerModel(pl.LightningModule):
 #     def __init__(self, params):
@@ -83,18 +84,40 @@ class ColorTransformerModel(pl.LightningModule):
     def __init__(self, params):
         super().__init__()
         self.save_hyperparameters(params)
-
+        # self.a = nn.Sequential(
+        #     nn.Linear(3, 3, bias=False),
+        #     nn.ReLU(),
+        #     nn.Linear(3, 3, bias=False),
+        #     nn.ReLU(),
+        #     nn.Linear(3, 1, bias=False),
+        #     nn.ReLU(),
+        # )
+        # self.b = nn.Sequential(
+        #     nn.Linear(3, 3, bias=False),
+        #     nn.ReLU(),
+        #     nn.Linear(3, 3, bias=False),
+        #     nn.ReLU(),
+        #     nn.Linear(3, 1, bias=False),
+        #     nn.ReLU(),
+        # )
         # Neural network layers
         self.network = nn.Sequential(
-            nn.Linear(3, self.hparams.width),
-            nn.ReLU(),
-            nn.Linear(self.hparams.width, 64),
-            nn.ReLU(),
-            nn.Linear(64, 1),
+            nn.Linear(5, 64),
+            nn.Tanh(),
+            nn.Linear(64, self.hparams.width),
+            nn.Tanh(),
+            nn.Linear(self.hparams.width, 3),
+            nn.Tanh(),
+            nn.Linear(3, 1),
         )
 
     def forward(self, x):
         # Pass the input through the network
+        # a = self.a(x)
+        # b = self.b(x)
+        # a = torch.sigmoid(a)
+        # b = torch.sigmoid(b)
+        # x = torch.cat([x, a, b], dim=-1)
         x = self.network(x)
         # Circular mapping
         # x = (torch.sin(x) + 1) / 2
@@ -104,7 +127,14 @@
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        s_loss = calculate_separation_loss(model=self)
+        # s_loss = calculate_separation_loss(model=self)
+        # preserve distance to pure R, G, B. this acts kind of like labeled data.
+        s_loss = preservation_loss(
+            inputs,
+            outputs,
+            target_inputs=PURE_RGB,
+            target_outputs=PURE_HSV,
+        )
         p_loss = preservation_loss(
             inputs,
             outputs,
diff --git a/search.py b/search.py
index 6b0b152..878f87a 100644
--- a/search.py
+++ b/search.py
@@ -20,12 +20,12 @@ NUM_JOBS = 100
 # Define the ranges or sets of values for each hyperparameter
 # alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
 # learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
-learning_rate_values = [1e-2, 1e-3]
+learning_rate_values = [1e-2]
 alpha_values = [0, 1, 2]
-widths = [64, 128, 256, 512]
+widths = [2**k for k in range(4, 15)]
 # learning_rate_values = [5e-4]
-batch_size_values = [8192]
-max_epochs_values = [50]
+batch_size_values = [256]
+max_epochs_values = [100]
 seeds = list(range(21, 1992))
 
 # Generate all possible combinations of hyperparameters
diff --git a/utils.py b/utils.py
index 4bea669..3dcc80d 100644
--- a/utils.py
+++ b/utils.py
@@ -2,7 +2,7 @@ import matplotlib.colors as mcolors
 import torch
 
 
-def preprocess_data(data, skip=True):
+def preprocess_data(data, skip: bool = False):
     # Assuming 'data' is a tensor of shape [n_samples, 3]
     if not skip:
         # Compute argmin and argmax for each row
@@ -37,3 +37,4 @@ def extract_colors():
 PURE_RGB = preprocess_data(
     torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32)
 )
+PURE_HSV = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)
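Note on the new PURE_HSV constant: together with PURE_RGB it anchors pure red,
green, and blue at their HSV hues (0, 1/3, 2/3) in the 1D output space, which
is what the new s_loss term in model.py penalizes against. A quick check of
the circular geometry it encodes, mirroring the (outputs - targets) % 1
convention in preservation_loss:

    import torch

    PURE_HSV = torch.tensor([[0], [1 / 3], [2 / 3]], dtype=torch.float32)

    # Pairwise hue differences modulo 1: red, green, and blue sit a third of
    # the circle apart, so every off-diagonal entry is 1/3 or 2/3.
    d = (PURE_HSV[:, None] - PURE_HSV[None, :]) % 1
    print(d.squeeze(-1))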