
initial commit

Branch: new-sep-loss
Michael Pilosov, 11 months ago
commit 9935640a81
 .gitignore    |  3 +
 dataloader.py | 58 +
 losses.py     | 31 +
 main.py       | 67 +
 makefile      |  7 +
 model.py      | 40 +
 6 files changed, 206 insertions(+)

.gitignore (+3)

@@ -0,0 +1,3 @@
lightning_logs/
__pycache__/
.sw[opqr]

dataloader.py (+58)

@@ -0,0 +1,58 @@
import matplotlib.colors as mcolors
import torch
from torch.utils.data import DataLoader, TensorDataset


def extract_colors():
    # Extract the list of xkcd colors as RGB triples
    xkcd_colors = mcolors.XKCD_COLORS
    rgb_values = [mcolors.to_rgb(color) for color in xkcd_colors.values()]

    # Extract the list of xkcd color names
    xkcd_color_names = list(xkcd_colors.keys())

    # Convert the list of RGB triples to a PyTorch tensor
    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32)
    return rgb_tensor, xkcd_color_names


def create_dataloader(**kwargs):
    rgb_tensor, _ = extract_colors()

    # Create a dataset and data loader
    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))  # Dummy labels
    train_dataloader = DataLoader(dataset, **kwargs)
    return train_dataloader


def create_named_dataloader(**kwargs):
    rgb_tensor, xkcd_color_names = extract_colors()

    # Create a dataset pairing RGB values with their corresponding color names
    dataset_with_names = [
        (rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor))
    ]
    train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
    return train_dataloader_with_names


if __name__ == "__main__":
    batch_size = 4
    train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
    train_dataloader_with_names = create_named_dataloader(
        batch_size=batch_size, shuffle=True
    )

    # Extract a sample from the DataLoader
    sample_data = next(iter(train_dataloader))

    # Sample RGB values and their corresponding dummy labels
    sample_rgb_values, _ = sample_data
    print(sample_rgb_values)

    # Extract a sample from the named DataLoader
    sample_data_with_names = next(iter(train_dataloader_with_names))

    # Sample RGB values and their corresponding color names
    sample_rgb_values_with_names, sample_color_names = sample_data_with_names
    print(sample_rgb_values_with_names, sample_color_names)
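Worth noting (not part of the commit): PyTorch's default collate function stacks the RGB tensors but leaves the string names as a plain list, so each batch from create_named_dataloader is a (Tensor, list-of-str) pair. A minimal sketch:

    # Sketch: inspect how the default collate_fn batches (tensor, string) pairs
    from dataloader import create_named_dataloader

    loader = create_named_dataloader(batch_size=4, shuffle=False)
    rgb_batch, name_batch = next(iter(loader))
    print(rgb_batch.shape)  # torch.Size([4, 3]): RGB triples stacked into one tensor
    print(name_batch)       # list of 4 name strings, e.g. ['xkcd:cloudy blue', ...]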

losses.py (+31)

@@ -0,0 +1,31 @@
import torch


def weighted_loss(inputs, outputs, alpha):
    # Calculate RGB Norm (Perceptual Difference)
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Calculate 1D Space Norm
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Weighted Loss
    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
    return torch.mean(loss)


def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
    # Calculate RGB Norm
    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

    # Calculate 1D Space Norm
    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

    # Identify Distinct Colors (based on a threshold in RGB space)
    distinct_colors = rgb_norm > distinct_threshold

    # Penalty for Distinct Colors being too close in the transformed space
    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)

    # Combined Loss
    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
    return torch.mean(loss)
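For intuition about the broadcasting trick used in both functions (a sketch, not part of the commit): inputs[:, None, :] - inputs[None, :, :] produces an (N, N, 3) tensor of pairwise differences, and the norm over the last axis gives the (N, N) matrix of pairwise RGB distances. A toy run with four colors and a made-up 1-D embedding:

    import torch

    from losses import enhanced_loss, weighted_loss

    # Four RGB colors: black, red, green, and a near-red
    inputs = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]]
    )
    outputs = torch.tensor([[0.0], [0.5], [1.0], [0.51]])  # toy 1-D embedding

    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
    print(rgb_norm.shape)  # torch.Size([4, 4]): pairwise RGB distances

    print(weighted_loss(inputs, outputs, alpha=0.5))
    print(enhanced_loss(inputs, outputs, alpha=0.5, distinct_threshold=0.5))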

main.py (+67)

@@ -0,0 +1,67 @@
import argparse

import pytorch_lightning as pl

from dataloader import create_named_dataloader as init_data
from model import ColorTransformerModel

if __name__ == "__main__":
    # Define argument parser
    parser = argparse.ArgumentParser(description="Color Transformer Training Script")

    # Add arguments
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=64,
        help="Input batch size for training",
    )
    parser.add_argument(
        "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
    )
    parser.add_argument(
        "-lr", "--learning_rate", type=float, default=1e-5, help="Learning rate"
    )
    parser.add_argument(
        "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
    )
    parser.add_argument(
        "-log", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
    )
    parser.add_argument(
        "--num_workers", type=int, default=3, help="Number of workers for data loading"
    )
    parser.add_argument(
        "-ds",
        "--distinct_threshold",
        type=float,
        default=0.5,
        help="Threshold for color distinctness penalty",
    )

    # Parse arguments
    args = parser.parse_args()

    # Initialize data loader with parsed arguments
    train_dataloader = init_data(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Initialize model with parsed arguments
    model = ColorTransformerModel(
        alpha=args.alpha,
        distinct_threshold=args.distinct_threshold,
        learning_rate=args.learning_rate,
    )

    # Initialize trainer with parsed arguments
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
    )

    # Train the model
    trainer.fit(model, train_dataloader)
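All of the flags above can be combined in a single run. An illustrative invocation (mirroring the makefile's test target below):

    python main.py -bs 128 --alpha 0.7 -lr 1e-4 -e 500 -ds 0.5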

makefile (+7)

@@ -0,0 +1,7 @@
lint:
	black .
	isort --profile=black .
	flake8 --ignore E501 .

test:
	python main.py --alpha 0.7 -lr 1e-5

model.py (+40)

@@ -0,0 +1,40 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn

from losses import enhanced_loss, weighted_loss


class ColorTransformerModel(pl.LightningModule):
    def __init__(self, alpha, distinct_threshold, learning_rate):
        super().__init__()
        self.save_hyperparameters()

        # Model layers
        self.layers = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch  # inputs are the RGB values; labels are the name strings
        outputs = self.forward(inputs)
        # loss = weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
        loss = enhanced_loss(
            inputs,
            outputs,
            alpha=self.hparams.alpha,
            distinct_threshold=self.hparams.distinct_threshold,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
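A usage sketch (not in the commit): since __init__ calls save_hyperparameters(), a trained model can be restored with Lightning's load_from_checkpoint and applied to the full color set. The checkpoint path here is hypothetical; Lightning writes checkpoints under lightning_logs/, which the .gitignore above excludes.

    import torch

    from dataloader import extract_colors
    from model import ColorTransformerModel

    # Hypothetical checkpoint path -- adjust to an actual run's output
    ckpt = "lightning_logs/version_0/checkpoints/last.ckpt"
    model = ColorTransformerModel.load_from_checkpoint(ckpt)
    model.eval()

    rgb_tensor, names = extract_colors()
    with torch.no_grad():
        embedded = model(rgb_tensor)  # (num_colors, 3): the learned transform
    print(embedded[:5], names[:5])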