From 9935640a81c4fb94a9ea2aa7464ef2f2a7dd019e Mon Sep 17 00:00:00 2001
From: Michael Pilosov
Date: Sat, 30 Dec 2023 04:37:06 +0000
Subject: [PATCH] initial commit

---
Review notes (this section is ignored when the patch is applied): the blob
"index" header lines are omitted because the file contents below no longer
match the blobs those hashes named.

 .gitignore    |  3 +++
 dataloader.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++
 losses.py     | 56 ++++++++++++++++++++++++++++++++++++++
 main.py       | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++
 makefile      |  9 +++++++
 model.py      | 52 +++++++++++++++++++++++++++++++++++
 6 files changed, 265 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 dataloader.py
 create mode 100644 losses.py
 create mode 100644 main.py
 create mode 100644 makefile
 create mode 100644 model.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+lightning_logs/
+__pycache__/
+*.sw[opqr]
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,70 @@
+"""Data loading helpers for the XKCD colors bundled with matplotlib."""
+
+import matplotlib.colors as mcolors
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+
+def extract_colors():
+    """Return the XKCD colors as an RGB tensor plus their names.
+
+    Returns:
+        A tuple ``(rgb_tensor, color_names)``: a float32 tensor holding one RGB
+        triple per color, and the list of ``mcolors.XKCD_COLORS`` keys in the
+        same order.
+    """
+    xkcd_colors = mcolors.XKCD_COLORS
+    # Names and values come from the same dict, so both share one ordering.
+    xkcd_color_names = list(xkcd_colors)
+    rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]
+
+    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
+    return rgb_tensor, xkcd_color_names
+
+
+def create_dataloader(**kwargs):
+    """Build a ``DataLoader`` over the RGB values paired with dummy labels.
+
+    Args:
+        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
+
+    Returns:
+        A ``DataLoader`` over ``(rgb, 0.0)`` samples.
+    """
+    rgb_tensor, _ = extract_colors()
+    # The zero labels are placeholders so the dataset has the usual pair shape.
+    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
+    return DataLoader(dataset, **kwargs)
+
+
+def create_named_dataloader(**kwargs):
+    """Build a ``DataLoader`` over ``(rgb, color_name)`` samples.
+
+    Args:
+        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
+
+    Returns:
+        A ``DataLoader`` whose samples pair an RGB tensor with its color name.
+    """
+    rgb_tensor, xkcd_color_names = extract_colors()
+    # Iterating the tensor yields one row per color, in the same order as names.
+    dataset_with_names = list(zip(rgb_tensor, xkcd_color_names))
+    return DataLoader(dataset_with_names, **kwargs)
+
+
+if __name__ == "__main__":
+    batch_size = 4
+    train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
+    train_dataloader_with_names = create_named_dataloader(
+        batch_size=batch_size, shuffle=True
+    )
+
+    # Print one batch of RGB values; the dummy labels are discarded
+    sample_rgb_values, _ = next(iter(train_dataloader))
+    print(sample_rgb_values)
+
+    # Print one batch of RGB values together with their color names
+    sample_rgb_values_with_names, sample_color_names = next(
+        iter(train_dataloader_with_names)
+    )
+    print(sample_rgb_values_with_names, sample_color_names)
diff --git a/losses.py b/losses.py
new file mode 100644
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,56 @@
+"""Pairwise-distance losses for the color model."""
+
+import torch
+
+
+def _pairwise_norm(points):
+    """Return the Euclidean distance between every pair of rows of *points*."""
+    return torch.norm(points[:, None] - points[None, :], dim=-1)
+
+
+def weighted_loss(inputs, outputs, alpha):
+    """Return the mean of a weighted sum of input and output pairwise distances.
+
+    Args:
+        inputs: Batch of RGB values, one color per row.
+        outputs: Transformed values for the same batch, one row per color.
+        alpha: Weight of the RGB-distance term; the transformed-distance term
+            is weighted by ``1 - alpha``.
+
+    Returns:
+        A scalar tensor.
+    """
+    rgb_norm = _pairwise_norm(inputs)
+    transformed_norm = _pairwise_norm(outputs)
+
+    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
+    return torch.mean(loss)
+
+
+def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
+    """Return the pairwise-distance loss plus a penalty on collapsed colors.
+
+    Args:
+        inputs: Batch of RGB values, one color per row.
+        outputs: Transformed values for the same batch, one row per color.
+        alpha: Weight of the RGB-distance term.
+        distinct_threshold: RGB distance above which two colors count as
+            distinct.
+
+    Returns:
+        A scalar tensor.
+    """
+    rgb_norm = _pairwise_norm(inputs)
+    transformed_norm = _pairwise_norm(outputs)
+
+    # Pairs of colors that are far apart in RGB space
+    distinct_colors = rgb_norm > distinct_threshold
+
+    # The inverse distance grows as distinct colors land close together in the
+    # transformed space; the epsilon avoids division by zero.
+    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)
+
+    # NOTE(review): unlike weighted_loss, transformed_norm is not scaled by
+    # (1 - alpha) here - confirm that this is intended.
+    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
+    return torch.mean(loss)
diff --git a/main.py b/main.py
new file mode 100644
--- /dev/null
+++ b/main.py
@@ -0,0 +1,75 @@
+"""Command-line entry point for training ``ColorTransformerModel``."""
+
+import argparse
+
+import pytorch_lightning as pl
+
+from dataloader import create_named_dataloader as init_data
+from model import ColorTransformerModel
+
+
+def parse_args(argv=None):
+    """Parse the training options.
+
+    Args:
+        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.
+
+    Returns:
+        The populated ``argparse.Namespace``.
+    """
+    parser = argparse.ArgumentParser(description="Color Transformer Training Script")
+    parser.add_argument(
+        "-bs",
+        "--batch_size",
+        type=int,
+        default=64,
+        help="Input batch size for training",
+    )
+    parser.add_argument(
+        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
+    )
+    parser.add_argument(
+        "-lr", "--learning_rate", type=float, default=1e-5, help="Learning rate"
+    )
+    parser.add_argument(
+        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
+    )
+    parser.add_argument(
+        "-log", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
+    )
+    parser.add_argument(
+        "--num_workers", type=int, default=3, help="Number of workers for data loading"
+    )
+    parser.add_argument(
+        "-ds",
+        "--distinct_threshold",
+        type=float,
+        default=0.5,
+        help="Threshold for color distinctness penalty",
+    )
+    return parser.parse_args(argv)
+
+
+def main(argv=None):
+    """Train ``ColorTransformerModel`` with the options given in *argv*."""
+    args = parse_args(argv)
+
+    train_dataloader = init_data(
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+    )
+    model = ColorTransformerModel(
+        alpha=args.alpha,
+        distinct_threshold=args.distinct_threshold,
+        learning_rate=args.learning_rate,
+    )
+    trainer = pl.Trainer(
+        max_epochs=args.max_epochs,
+        log_every_n_steps=args.log_every_n_steps,
+    )
+    trainer.fit(model, train_dataloader)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/makefile b/makefile
new file mode 100644
--- /dev/null
+++ b/makefile
@@ -0,0 +1,9 @@
+.PHONY: lint test
+
+lint:
+	black .
+	isort --profile=black .
+	flake8 --ignore E501 .
+
+test:
+	python main.py --alpha 0.7 -lr 1e-5
diff --git a/model.py b/model.py
new file mode 100644
--- /dev/null
+++ b/model.py
@@ -0,0 +1,52 @@
+"""Lightning module that maps RGB colors through a small fully connected net."""
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from losses import enhanced_loss
+
+
+class ColorTransformerModel(pl.LightningModule):
+    """Fully connected network trained with ``enhanced_loss``.
+
+    Args:
+        alpha: Passed to ``enhanced_loss`` as ``alpha``.
+        distinct_threshold: Passed to ``enhanced_loss`` as ``distinct_threshold``.
+        learning_rate: Learning rate of the Adam optimizer.
+    """
+
+    def __init__(self, alpha, distinct_threshold, learning_rate):
+        super().__init__()
+        # Makes the constructor arguments available through ``self.hparams``
+        self.save_hyperparameters()
+
+        self.layers = nn.Sequential(
+            nn.Linear(3, 128),
+            nn.ReLU(),
+            nn.Linear(128, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3),
+        )
+
+    def forward(self, x):
+        """Apply the network to a batch of RGB triples."""
+        return self.layers(x)
+
+    def training_step(self, batch, batch_idx):
+        """Compute and log ``train_loss`` for one batch of ``(rgb, name)`` pairs."""
+        # Only the RGB values are used; the color names are ignored.
+        inputs, _ = batch
+        outputs = self.forward(inputs)
+        loss = enhanced_loss(
+            inputs,
+            outputs,
+            alpha=self.hparams.alpha,
+            distinct_threshold=self.hparams.distinct_threshold,
+        )
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        """Return an Adam optimizer over all parameters of the model."""
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)