Michael Pilosov
11 months ago
commit
9935640a81
6 changed files with 206 additions and 0 deletions
@ -0,0 +1,3 @@ |
|||||
|
lightning_logs/ |
||||
|
__pycache__/ |
||||
|
.sw[opqr] |
@ -0,0 +1,58 @@ |
|||||
|
import matplotlib.colors as mcolors |
||||
|
import torch |
||||
|
from torch.utils.data import DataLoader, TensorDataset |
||||
|
|
||||
|
|
||||
|
def extract_colors():
    """Return the xkcd colors as a float32 tensor of RGB triples plus their names.

    Returns:
        A ``(rgb_tensor, names)`` pair. ``rgb_tensor`` holds one RGB triple per
        entry of ``matplotlib.colors.XKCD_COLORS``; ``names`` lists the
        matching dictionary keys in the same order.
    """
    palette = mcolors.XKCD_COLORS

    # Names and RGB triples are both taken in the dict's iteration order,
    # so index i in one refers to the same color as index i in the other.
    names = list(palette)
    triples = [mcolors.to_rgb(palette[name]) for name in names]

    return torch.tensor(triples, dtype=torch.float32), names
||||
|
|
||||
|
|
||||
|
def create_dataloader(**kwargs):
    """Build a DataLoader over the xkcd RGB tensor paired with dummy labels.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``).

    Returns:
        A DataLoader whose samples are ``(rgb, label)`` pairs.
    """
    colors, _names = extract_colors()
    # Zeros serve as dummy labels so every sample is an (rgb, label) pair.
    placeholder_labels = torch.zeros(len(colors))
    return DataLoader(TensorDataset(colors, placeholder_labels), **kwargs)
||||
|
|
||||
|
|
||||
|
def create_named_dataloader(**kwargs):
    """Build a DataLoader whose samples pair each RGB triple with its color name.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        A DataLoader over a list of ``(rgb_row, color_name)`` tuples.
    """
    colors, names = extract_colors()
    # A plain list of (tensor_row, name) tuples is used as the dataset.
    named_samples = [(colors[index], name) for index, name in enumerate(names)]
    return DataLoader(named_samples, **kwargs)
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__":
    # Smoke test: draw one shuffled batch from each loader and print it.
    loader_options = dict(batch_size=4, shuffle=True)
    plain_loader = create_dataloader(**loader_options)
    named_loader = create_named_dataloader(**loader_options)

    # Batches from the plain loader are (rgb, dummy_label) pairs; only the
    # RGB values are of interest here.
    rgb_batch, _ = next(iter(plain_loader))
    print(rgb_batch)

    # Batches from the named loader pair RGB values with their color names.
    named_rgb_batch, name_batch = next(iter(named_loader))
    print(named_rgb_batch, name_batch)
@ -0,0 +1,31 @@ |
|||||
|
import torch |
||||
|
|
||||
|
|
||||
|
def weighted_loss(inputs, outputs, alpha):
    """Blend pairwise distances in RGB space with those in the transformed space.

    Args:
        inputs: Tensor of shape ``(batch, channels)`` holding the input colors.
        outputs: Model outputs for the same batch, either ``(batch, dims)`` or
            a flat ``(batch,)`` tensor of scalar embeddings.
        alpha: Weight of the RGB distances; the transformed-space distances
            are weighted by ``1 - alpha``.

    Returns:
        Scalar tensor: the mean of the weighted ``(batch, batch)`` matrix of
        pairwise distances.
    """
    if outputs.dim() == 1:
        # Without a feature axis, the norm over dim=-1 below would collapse
        # the pair axis and yield a (batch,) vector that silently broadcasts
        # against the (batch, batch) RGB distances.
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances between the input colors.
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise Euclidean distances between the transformed points.
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)
||||
|
|
||||
|
|
||||
|
def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
    """Pairwise-distance loss with a penalty for distinct colors mapped close together.

    Args:
        inputs: Tensor of shape ``(batch, channels)`` holding the input colors.
        outputs: Model outputs for the same batch, either ``(batch, dims)`` or
            a flat ``(batch,)`` tensor of scalar embeddings.
        alpha: Weight applied to the pairwise RGB distances.
        distinct_threshold: Pairs whose RGB distance exceeds this value are
            treated as distinct colors.

    Returns:
        Scalar tensor: the mean over all pairs of
        ``alpha * rgb_distance + transformed_distance + distinct_penalty``.
    """
    if outputs.dim() == 1:
        # Without a feature axis, the norm over dim=-1 below would collapse
        # the pair axis and yield a (batch,) vector that silently broadcasts
        # against the (batch, batch) RGB distances.
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances between the input colors.
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise Euclidean distances between the transformed points.
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Boolean mask of pairs that are far apart in RGB space.
    distinct_colors = rgb_norm > distinct_threshold

    # Inverse transformed distance, counted only for distinct pairs; the small
    # epsilon keeps the division finite when two outputs coincide.
    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)

    # The scalar penalty is added to every entry of the pairwise matrix.
    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
    return torch.mean(loss)
@ -0,0 +1,67 @@ |
|||||
|
import argparse |
||||
|
|
||||
|
import pytorch_lightning as pl |
||||
|
|
||||
|
from dataloader import create_named_dataloader as init_data |
||||
|
from model import ColorTransformerModel |
||||
|
|
||||
|
if __name__ == "__main__":
    # Command-line options as (flags, settings) pairs. They are registered in
    # this order so the generated --help output keeps its layout.
    cli_options = [
        (
            ("-bs", "--batch_size"),
            dict(type=int, default=64, help="Input batch size for training"),
        ),
        (
            ("-a", "--alpha"),
            dict(type=float, default=0.5, help="Alpha value for loss function"),
        ),
        (
            ("-lr", "--learning_rate"),
            dict(type=float, default=1e-5, help="Learning rate"),
        ),
        (
            ("-e", "--max_epochs"),
            dict(type=int, default=1000, help="Number of epochs to train"),
        ),
        (
            ("-log", "--log_every_n_steps"),
            dict(type=int, default=5, help="Logging frequency"),
        ),
        (
            ("--num_workers",),
            dict(type=int, default=3, help="Number of workers for data loading"),
        ),
        (
            ("-ds", "--distinct_threshold"),
            dict(
                type=float,
                default=0.5,
                help="Threshold for color distinctness penalty",
            ),
        ),
    ]

    parser = argparse.ArgumentParser(description="Color Transformer Training Script")
    for flags, settings in cli_options:
        parser.add_argument(*flags, **settings)
    args = parser.parse_args()

    # Shuffled training data; batch size and worker count come from the CLI.
    train_dataloader = init_data(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Loss and optimizer hyperparameters are handed to the model.
    model = ColorTransformerModel(
        alpha=args.alpha,
        distinct_threshold=args.distinct_threshold,
        learning_rate=args.learning_rate,
    )

    # Epoch budget and logging cadence are handed to the trainer.
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )

    trainer.fit(model, train_dataloader)
@ -0,0 +1,7 @@ |
|||||
|
# Neither target produces a file of its own name; declaring them phony makes
# them run even if a file or directory called "lint" or "test" exists.
.PHONY: lint test

# Format the code, sort imports, and run flake8 with the E501 line-length check ignored.
lint:
	black .
	isort --profile=black .
	flake8 --ignore E501 .

# Run the training script with an explicit alpha and learning rate.
test:
	python main.py --alpha 0.7 -lr 1e-5
@ -0,0 +1,40 @@ |
|||||
|
import pytorch_lightning as pl |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
|
||||
|
from losses import enhanced_loss, weighted_loss |
||||
|
|
||||
|
|
||||
|
class ColorTransformerModel(pl.LightningModule):
    """Lightning module: a small MLP over RGB triples trained with ``enhanced_loss``.

    Args:
        alpha: Weight passed to the loss for the pairwise RGB-distance term.
        distinct_threshold: RGB-distance threshold passed to the loss to
            decide which color pairs count as distinct.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        # Records the constructor arguments so they are available as
        # self.hparams.<name> in the methods below.
        self.save_hyperparameters()

        # Model layers: 3 inputs -> 128 -> 128 -> 3 outputs with ReLU between.
        self.layers = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: ``(inputs, labels)`` pair; inputs are the RGB values and
                labels are the color-name strings, which are not used here.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss tensor.
        """
        inputs, _labels = batch
        # Call the module rather than .forward() so that any registered
        # forward hooks are run as well.
        outputs = self(inputs)
        # Alternative objective:
        # weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
Loading…
Reference in new issue