diff --git a/check.py b/check.py
index 60f5b71..92fc646 100644
--- a/check.py
+++ b/check.py
@@ -2,9 +2,10 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import torch

-from dataloader import extract_colors
+from dataloader import extract_colors, preprocess_data
 from model import ColorTransformerModel

+
 def make_image(ckpt: str, fname: str, color=True):
     M = ColorTransformerModel.load_from_checkpoint(ckpt)
@@ -17,14 +18,15 @@ def make_image(ckpt: str, fname: str, color=True):

     rgb_tensor, names = extract_colors()
     rgb_values = rgb_tensor.detach().numpy()
+    rgb_tensor = preprocess_data(rgb_tensor)
     preds = M(rgb_tensor)
     sorted_inds = np.argsort(preds.detach().numpy().ravel())

-    fig, ax = plt.subplots(figsize=(10, 5))
+    fig, ax = plt.subplots(figsize=(20, 5))
     for i in range(len(sorted_inds)):
         idx = sorted_inds[i]
         color = rgb_values[idx]
-        ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5)
+        ax.plot([i, i], [0, 5], lw=0.5, c=color, antialiased=False, alpha=1)
     ax.axis("off")
     # ax.axis("square")

@@ -32,7 +34,8 @@


 if __name__ == "__main__":
-
-    name = "color_128_0.3_1.00e-06"
-    ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
+    # name = "color_128_0.3_1.00e-06"
+    name = "color_64_1_1.0e-3.png"
+    # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
+    ckpt = "/teamspace/studios/this_studio/colors/lightning_logs/version_26/checkpoints/epoch=99-step=1500.ckpt"
     make_image(ckpt, fname=name)
diff --git a/dataloader.py b/dataloader.py
index d574697..9fabf23 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -18,14 +18,16 @@ def extract_colors():

 def create_dataloader(**kwargs):
     rgb_tensor, _ = extract_colors()
+    rgb_tensor = preprocess_data(rgb_tensor)
     # Creating a dataset and data loader
-    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))  # Dummy labels
+    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
     train_dataloader = DataLoader(dataset, **kwargs)
     return train_dataloader


 def create_named_dataloader(**kwargs):
     rgb_tensor, xkcd_color_names = extract_colors()
+    rgb_tensor = preprocess_data(rgb_tensor)
     # Creating a dataset with RGB values and their corresponding color names
     dataset_with_names = [
         (rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor))
@@ -34,6 +36,24 @@ def create_named_dataloader(**kwargs):
     return train_dataloader_with_names


+def preprocess_data(data):
+    # Assuming 'data' is a tensor of shape [n_samples, 3]
+
+    # Compute argmin and argmax for each row
+    argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
+    argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
+
+    # Normalize the argmin and argmax indices to [0, 1)
+    # by dividing by the number of features
+    argmin_values /= data.shape[1]
+    argmax_values /= data.shape[1]
+
+    # Concatenate the argmin and argmax values to the original data
+    new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
+
+    return new_data
+
+
 if __name__ == "__main__":
     batch_size = 4
     train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
diff --git a/losses.py b/losses.py
index f237949..7d15f07 100644
--- a/losses.py
+++ b/losses.py
@@ -1,16 +1,15 @@
 import torch

+# def weighted_loss(inputs, outputs, alpha):
+#     # Calculate RGB Norm (Perceptual Difference)
+#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

-def weighted_loss(inputs, outputs, alpha):
-    # Calculate RGB Norm (Perceptual Difference)
-    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
-
-    # Calculate 1D Space Norm
-    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
+#     # Calculate 1D Space Norm
+#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

-    # Weighted Loss
-    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
-    return torch.mean(loss)
+#     # Weighted Loss
+#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
+#     return torch.mean(loss)


 # def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
@@ -33,7 +32,7 @@
 #     return torch.mean(loss)


-def enhanced_loss(inputs, outputs, alpha):
+def preservation_loss(inputs, outputs):
     # Calculate RGB Norm
     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

@@ -42,19 +41,19 @@

     # Distance Preservation Component
     # Encourages the model to keep relative distances from the RGB space in the transformed space
-    distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm))
-
-    # Combined Loss
-    loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs)
-    return loss
+    return torch.mean(torch.abs(rgb_norm - transformed_norm))


 def smoothness_loss(outputs):
     # Sort outputs for smoothness calculation
     sorted_outputs, _ = torch.sort(outputs, dim=0)
+    first_elements = sorted_outputs[:2]
+
+    # Append the first two sorted elements so the second derivative wraps around
+    extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)

     # Calculate smoothness in the sorted outputs
-    first_derivative = torch.diff(sorted_outputs, n=1, dim=0)
+    first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
     second_derivative = torch.diff(first_derivative, n=1, dim=0)
     smoothness_loss = torch.mean(torch.abs(second_derivative))
     return smoothness_loss
diff --git a/main.py b/main.py
index 0864bdf..a7d5fd5 100644
--- a/main.py
+++ b/main.py
@@ -50,12 +50,15 @@
         num_workers=args.num_workers,
     )

-    # Initialize model with parsed arguments
-    model = ColorTransformerModel(
+    params = argparse.Namespace(
         alpha=args.alpha,
         learning_rate=args.lr,
+        batch_size=args.bs,
     )

+    # Initialize model with parsed arguments
+    model = ColorTransformerModel(params)
+
     # Initialize trainer with parsed arguments
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
diff --git a/makefile b/makefile
index b0825f4..ed2be99 100644
--- a/makefile
+++ b/makefile
@@ -4,4 +4,4 @@ lint:
 	flake8 --ignore E501 .

 test:
-	python main.py --alpha 0.7 --lr 1e-3 --max_epochs 1000
\ No newline at end of file
+	python main.py --alpha 1 --lr 1e-3 --max_epochs 100
\ No newline at end of file
diff --git a/model.py b/model.py
index fd7fc92..e459869 100644
--- a/model.py
+++ b/model.py
@@ -3,17 +3,17 @@
 import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from losses import enhanced_loss, weighted_loss  # noqa: F401
+from losses import preservation_loss, smoothness_loss


 class ColorTransformerModel(pl.LightningModule):
-    def __init__(self, alpha, learning_rate):
+    def __init__(self, params):
         super().__init__()
-        self.save_hyperparameters()
+        self.save_hyperparameters(params)

         # Model layers
         self.layers = nn.Sequential(
-            nn.Linear(3, 128),
+            nn.Linear(5, 128),
             nn.ReLU(),
             nn.Linear(128, 128),
             nn.ReLU(),
@@ -22,7 +22,7 @@
     def forward(self, x):
         x = self.layers(x)
-        x = torch.sigmoid(x)
+        x = (torch.sin(x) + 1) / 2
         return x


 # class ColorTransformerModel(pl.LightningModule):
@@ -68,12 +68,14 @@

     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # loss = weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
-        loss = enhanced_loss(
+        s_loss = smoothness_loss(outputs)
+        p_loss = preservation_loss(
             inputs,
             outputs,
-            alpha=self.hparams.alpha,
         )
+        alpha = self.hparams.alpha
+        loss = p_loss + alpha * s_loss
+        self.log("hp_metric", p_loss)
         self.log("train_loss", loss)
         return loss
diff --git a/scrape.py b/scrape.py
index 1da8248..90f0685 100644
--- a/scrape.py
+++ b/scrape.py
@@ -1,6 +1,7 @@
 import glob
-from pathlib import Path
 import shutil
+from pathlib import Path
+

 from check import make_image

@@ -9,16 +10,18 @@ def get_exps(pattern: str, splitter: str = "_"):
     chkpt_basedir = "/work/colors/lightning_logs/"
     location = basedir + pattern
     res = glob.glob(location)
-    location = location.replace('*', '')
+    location = location.replace("*", "")
     H = []  # hyperparams used
     # print(res)
     for r in res:
-        d = r.replace(location, '').split(splitter)
+        d = r.replace(location, "").split(splitter)
         d = list(float(_d) for _d in d)
         d[0] = int(d[0])
         H.append(d)
     for i, r in enumerate(res):
-        dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/")
+        dir_path = Path(
+            f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/"
+        )
         dir_path.mkdir(parents=True, exist_ok=True)
         g = glob.glob(r + chkpt_basedir + "*")
         c = g[0] + "/checkpoints"
@@ -26,7 +29,7 @@ def get_exps(pattern: str, splitter: str = "_"):
         # print(latest_checkpoint)
         logs = glob.glob(g[0] + "/events*")[-1]
         print(logs)
-        source_path = Path(logs)
+        # source_path = Path(logs)
         # print("Would copy", source_path, dir_path)
         # shutil.copy(source_path, dir_path)
         make_image(latest_checkpoint, f"out/version_{i}")
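
Notes on the patch. The dataloader and model changes are coupled: preprocess_data widens each RGB row from 3 to 5 features (the two extra channels are the argmin and argmax indices divided by 3), which is why the input layer changes from nn.Linear(3, 128) to nn.Linear(5, 128). A minimal sketch of the shape contract, assuming the patched dataloader module is importable; the sample values are made up, not real XKCD data:

    import torch

    from dataloader import preprocess_data

    # Two made-up RGB rows in [0, 1].
    rgb = torch.tensor([[1.0, 0.2, 0.1],
                        [0.3, 0.9, 0.0]])
    out = preprocess_data(rgb)

    # Row 0: argmin is channel 2, argmax is channel 0, each scaled by 1/3.
    assert out.shape == (2, 5)
    assert torch.allclose(out[0], torch.tensor([1.0, 0.2, 0.1, 2 / 3, 0.0]))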
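The smoothness_loss change appends the two smallest sorted outputs before differencing, so the second-derivative penalty also spans the gap between the largest and smallest values, a wrap-around that matches the new periodic output activation (sin(x) + 1) / 2. A worked toy example, with made-up outputs, assuming the patched losses module is importable:

    import torch

    from losses import smoothness_loss

    outputs = torch.tensor([[0.1], [0.4], [0.2], [0.9]])
    # sorted:   0.1, 0.2, 0.4, 0.9
    # extended: 0.1, 0.2, 0.4, 0.9, 0.1, 0.2
    # 1st diff: 0.1, 0.2, 0.5, -0.8, 0.1
    # 2nd diff: 0.1, 0.3, -1.3, 0.9  ->  mean(|.|) = 0.65
    print(smoothness_loss(outputs))  # tensor(0.6500)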
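Finally, training_step now optimizes preservation_loss(inputs, outputs) + alpha * smoothness_loss(outputs), so alpha acts as a weight on the smoothness term rather than the old interpolation coefficient (hence --alpha 1 in the makefile, with p_loss alone logged as hp_metric). A standalone sketch of the objective with made-up tensors; note that after the dataloader change, inputs carry the two extra channels, so the "RGB norm" inside preservation_loss is actually computed in the augmented 5-D space:

    import torch

    from losses import preservation_loss, smoothness_loss

    inputs = torch.rand(8, 5)   # batch of preprocessed colors (RGB + argmin/argmax)
    outputs = torch.rand(8, 1)  # the model's 1-D embedding for each color
    alpha = 1.0
    loss = preservation_loss(inputs, outputs) + alpha * smoothness_loss(outputs)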