Michael Pilosov
11 months ago
commit
9935640a81
6 changed files with 206 additions and 0 deletions
@ -0,0 +1,3 @@ |
|||
lightning_logs/ |
|||
__pycache__/ |
|||
.sw[opqr] |
@ -0,0 +1,58 @@ |
|||
import matplotlib.colors as mcolors |
|||
import torch |
|||
from torch.utils.data import DataLoader, TensorDataset |
|||
|
|||
|
|||
def extract_colors():
    """Return the xkcd palette as a float32 RGB tensor plus its color names.

    Returns:
        tuple: ``(rgb_tensor, names)``. ``rgb_tensor`` holds one RGB triple
        (as produced by ``mcolors.to_rgb``) per entry of
        ``mcolors.XKCD_COLORS``; ``names`` lists that mapping's keys in the
        same order, so row ``i`` of the tensor belongs to ``names[i]``.
    """
    palette = mcolors.XKCD_COLORS
    names = list(palette)

    # Look each value up by key so tensor rows and names share one ordering.
    triples = [mcolors.to_rgb(palette[name]) for name in names]

    return torch.tensor(triples, dtype=torch.float32), names
|||
|
|||
|
|||
def create_dataloader(**kwargs):
    """Build a DataLoader over the xkcd RGB values paired with dummy labels.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``).

    Returns:
        DataLoader: yields ``(rgb_batch, zeros_batch)`` pairs.
    """
    colors, _names = extract_colors()

    # TensorDataset needs a second tensor; zeros act as placeholder labels.
    placeholder_labels = torch.zeros(len(colors))

    return DataLoader(TensorDataset(colors, placeholder_labels), **kwargs)
|||
|
|||
|
|||
def create_named_dataloader(**kwargs):
    """Build a DataLoader whose samples are ``(rgb_row, color_name)`` pairs.

    Args:
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        DataLoader: wraps a plain list of ``(tensor, str)`` tuples, one per
        xkcd color returned by ``extract_colors``.
    """
    colors, names = extract_colors()

    # Iterating the tensor walks its first dimension, i.e. one row per color.
    named_samples = list(zip(colors, names))

    return DataLoader(named_samples, **kwargs)
|||
|
|||
|
|||
if __name__ == "__main__":
    # Manual smoke check: draw one shuffled batch from each loader and print it.
    loader_kwargs = {"batch_size": 4, "shuffle": True}
    plain_loader = create_dataloader(**loader_kwargs)
    named_loader = create_named_dataloader(**loader_kwargs)

    # First loader: RGB values with dummy labels (labels are discarded).
    rgb_batch, _dummy_labels = next(iter(plain_loader))
    print(rgb_batch)

    # Second loader: RGB values together with their color names.
    named_rgb_batch, name_batch = next(iter(named_loader))
    print(named_rgb_batch, name_batch)
@ -0,0 +1,31 @@ |
|||
import torch |
|||
|
|||
|
|||
def weighted_loss(inputs, outputs, alpha):
    """Blend pairwise input-space and output-space distances into one scalar.

    Args:
        inputs: 2-D tensor ``(N, C)`` of samples (e.g. RGB triples).
        outputs: Transformed samples, either ``(N, D)`` or a 1-D tensor
            ``(N,)`` holding one scalar per sample.
        alpha: Weight of the input-space distances; the output-space
            distances are weighted by ``1 - alpha``.

    Returns:
        Scalar tensor: mean over all ``N x N`` pairs of
        ``alpha * ||x_i - x_j|| + (1 - alpha) * ||y_i - y_j||``.
    """
    # A 1-D ``outputs`` must become a single-feature column first: otherwise
    # ``outputs[:, None] - outputs[None, :]`` is already the (N, N) difference
    # matrix and the norm over ``dim=-1`` would collapse it to shape (N,),
    # which then silently broadcasts against the (N, N) input distances.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances in input (RGB) space, shape (N, N).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise Euclidean distances in the transformed space, shape (N, N).
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)
|||
|
|||
|
|||
def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
    """Pairwise-distance loss with a penalty for collapsing distinct colors.

    Args:
        inputs: 2-D tensor ``(N, C)`` of samples (e.g. RGB triples).
        outputs: Transformed samples, either ``(N, D)`` or a 1-D tensor
            ``(N,)`` holding one scalar per sample.
        alpha: Weight applied to the input-space distances.
        distinct_threshold: Pairs whose input-space distance exceeds this
            value are treated as distinct and contribute to the penalty.

    Returns:
        Scalar tensor: mean over all pairs of
        ``alpha * ||x_i - x_j|| + ||y_i - y_j|| + penalty``.
    """
    # A 1-D ``outputs`` must become a single-feature column first: otherwise
    # the norm over ``dim=-1`` would collapse the (N, N) difference matrix to
    # shape (N,) and silently broadcast against the (N, N) input distances.
    if outputs.dim() == 1:
        outputs = outputs.unsqueeze(-1)

    # Pairwise Euclidean distances in input (RGB) space, shape (N, N).
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Pairwise Euclidean distances in the transformed space, shape (N, N).
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Boolean mask of pairs that are far apart in input space.
    distinct_colors = rgb_norm > distinct_threshold

    # Inverse output distance, counted only for distinct pairs; the epsilon
    # avoids division by zero (e.g. on the diagonal). This is a single scalar.
    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)

    # NOTE: unlike ``weighted_loss``, the transformed term is not scaled by
    # ``1 - alpha`` here, and the scalar penalty is added to every pair.
    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
    return torch.mean(loss)
@ -0,0 +1,67 @@ |
|||
import argparse |
|||
|
|||
import pytorch_lightning as pl |
|||
|
|||
from dataloader import create_named_dataloader as init_data |
|||
from model import ColorTransformerModel |
|||
|
|||
if __name__ == "__main__":
    # Command-line options, registered in the order they appear in --help.
    cli = argparse.ArgumentParser(description="Color Transformer Training Script")
    option_table = [
        (
            ("-bs", "--batch_size"),
            {"type": int, "default": 64, "help": "Input batch size for training"},
        ),
        (
            ("-a", "--alpha"),
            {"type": float, "default": 0.5, "help": "Alpha value for loss function"},
        ),
        (
            ("-lr", "--learning_rate"),
            {"type": float, "default": 1e-5, "help": "Learning rate"},
        ),
        (
            ("-e", "--max_epochs"),
            {"type": int, "default": 1000, "help": "Number of epochs to train"},
        ),
        (
            ("-log", "--log_every_n_steps"),
            {"type": int, "default": 5, "help": "Logging frequency"},
        ),
        (
            ("--num_workers",),
            {"type": int, "default": 3, "help": "Number of workers for data loading"},
        ),
        (
            ("-ds", "--distinct_threshold"),
            {
                "type": float,
                "default": 0.5,
                "help": "Threshold for color distinctness penalty",
            },
        ),
    ]
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)

    args = cli.parse_args()

    # Shuffled loader of (rgb, color_name) samples.
    train_dataloader = init_data(
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )

    # Model hyperparameters come straight from the command line.
    model = ColorTransformerModel(
        alpha=args.alpha,
        distinct_threshold=args.distinct_threshold,
        learning_rate=args.learning_rate,
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs, log_every_n_steps=args.log_every_n_steps
    )

    # Run training.
    trainer.fit(model, train_dataloader)
@ -0,0 +1,7 @@ |
|||
# Neither target produces a file of its own name; declare them phony so a
# stray file or directory called "lint" or "test" cannot make them no-ops.
.PHONY: lint test

# Format the code, sort imports, then lint (line-length check E501 ignored).
lint:
	black .
	isort --profile=black .
	flake8 --ignore E501 .

# Run the training script with an explicit alpha and learning rate; every
# other option keeps the default declared in main.py.
test:
	python main.py --alpha 0.7 -lr 1e-5
@ -0,0 +1,40 @@ |
|||
import pytorch_lightning as pl |
|||
import torch |
|||
import torch.nn as nn |
|||
|
|||
from losses import enhanced_loss, weighted_loss |
|||
|
|||
|
|||
class ColorTransformerModel(pl.LightningModule):
    """Small MLP mapping 3-feature inputs to 3-feature outputs.

    Training minimizes ``enhanced_loss`` between a batch of inputs and the
    network's outputs. ``weighted_loss`` from the same ``losses`` module is a
    simpler alternative that takes only ``alpha``.

    Args:
        alpha: Weight passed to the loss function.
        distinct_threshold: Distance threshold passed to ``enhanced_loss``.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        # Makes the constructor arguments available as ``self.hparams.<name>``.
        self.save_hyperparameters()

        # 3 -> 128 -> 128 -> 3 fully connected network with ReLU activations.
        self.layers = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must be 3)."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        ``batch`` is an ``(inputs, labels)`` pair; the labels (color names
        when fed by the named dataloader) are not used by the loss.
        """
        inputs, _ = batch
        # Call the module rather than ``forward`` directly so that any
        # registered forward hooks are honoured.
        outputs = self(inputs)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
Loading…
Reference in new issue