
optimizations

Branch: new-sep-loss
Michael Pilosov, 11 months ago
Commit: 95af36221b
Changed files:
  1. losses.py (42 lines changed)
  2. main.py (21 lines changed)
  3. makefile (2 lines changed)
  4. model.py (25 lines changed)
  5. search.py (26 lines changed)

losses.py (42 lines changed)

@@ -13,21 +13,45 @@ def weighted_loss(inputs, outputs, alpha):
     return torch.mean(loss)
 
 
-def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
-    # Calculate RGB Norm
-    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
-
-    # Calculate 1D Space Norm
-    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
-
-    # Identify Distinct Colors (based on a threshold in RGB space)
-    distinct_colors = rgb_norm > distinct_threshold
-
-    # Penalty for Distinct Colors being too close in the transformed space
-    # Here we do not take the mean yet, to avoid double averaging
-    distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
-
-    # Combined Loss
-    # The mean is taken here, once, after all components are combined
-    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
-    return torch.mean(loss)
+# def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
+#     # Calculate RGB Norm
+#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
+
+#     # Calculate 1D Space Norm
+#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
+
+#     # Identify Distinct Colors (based on a threshold in RGB space)
+#     distinct_colors = rgb_norm > distinct_threshold
+
+#     # Penalty for Distinct Colors being too close in the transformed space
+#     # Here we do not take the mean yet, to avoid double averaging
+#     distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
+
+#     # Combined Loss
+#     # The mean is taken here, once, after all components are combined
+#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
+#     return torch.mean(loss)
+
+
+def enhanced_loss(inputs, outputs, alpha):
+    # Calculate RGB Norm
+    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
+
+    # Calculate 1D Space Norm
+    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
+
+    # Distance Preservation Component
+    # Encourages the model to keep relative distances from the RGB space in the transformed space
+    distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm))
+
+    # Sort outputs for smoothness calculation
+    sorted_outputs, _ = torch.sort(outputs, dim=0)
+
+    # Calculate smoothness in the sorted outputs
+    first_derivative = torch.diff(sorted_outputs, n=1, dim=0)
+    second_derivative = torch.diff(first_derivative, n=1, dim=0)
+    smoothness_loss = torch.mean(torch.abs(second_derivative))
+
+    # Combined Loss
+    loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss
+    return loss
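
The new loss replaces the hard distinctness threshold with two differentiable terms: alpha * mean(|rgb_norm - transformed_norm|), which pushes pairwise distances in the 1D embedding toward the pairwise RGB distances, and (1 - alpha) * mean(|second difference of the sorted outputs|), which smooths the learned ordering. A minimal sketch of calling the new function on a toy batch; the shapes, inputs of (N, 3) and outputs of (N, 1), are assumptions inferred from the pairwise indexing above:

    import torch

    from losses import enhanced_loss  # assumes the new signature shown in the diff

    # Toy batch: 8 colors in RGB, shape (N, 3), and their 1D embeddings, shape (N, 1).
    inputs = torch.rand(8, 3)
    outputs = torch.rand(8, 1)

    loss = enhanced_loss(inputs, outputs, alpha=0.7)
    print(loss.item())  # scalar: 0.7 * distance-preservation + 0.3 * smoothness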

main.py (21 lines changed)

@@ -5,6 +5,7 @@ import pytorch_lightning as pl
 from dataloader import create_named_dataloader as init_data
 from model import ColorTransformerModel
 
+
 def parse_args():
     # Define argument parser
     parser = argparse.ArgumentParser(description="Color Transformer Training Script")
@@ -19,9 +20,7 @@ def parse_args():
     parser.add_argument(
         "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
     )
-    parser.add_argument(
-        "--lr", type=float, default=1e-5, help="Learning rate"
-    )
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
     parser.add_argument(
         "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
     )
@@ -29,22 +28,19 @@ def parse_args():
         "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
     )
     parser.add_argument(
-        "-w", "--num_workers", type=int, default=3, help="Number of workers for data loading"
-    )
-    parser.add_argument(
-        "-D",
-        "--distinct_threshold",
-        type=float,
-        default=0.5,
-        help="Threshold for color distinctness penalty",
+        "-w",
+        "--num_workers",
+        type=int,
+        default=3,
+        help="Number of workers for data loading",
     )
 
     # Parse arguments
     args = parser.parse_args()
     return args
 
 
 if __name__ == "__main__":
     args = parse_args()
 
     # Initialize data loader with parsed arguments
@@ -57,7 +53,6 @@ if __name__ == "__main__":
     # Initialize model with parsed arguments
     model = ColorTransformerModel(
         alpha=args.alpha,
-        distinct_threshold=args.distinct_threshold,
         learning_rate=args.lr,
     )
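
With `-D`/`--distinct_threshold` gone, any caller still passing it will exit with an argparse error. A hypothetical sanity check of the reduced interface, with flags copied from the diff (this snippet is illustrative, not part of the commit):

    import argparse

    parser = argparse.ArgumentParser(description="Color Transformer Training Script")
    parser.add_argument("-a", "--alpha", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("-e", "--max_epochs", type=int, default=1000)

    args = parser.parse_args(["--alpha", "0.7", "--lr", "1e-3"])
    print(args.alpha, args.lr)  # 0.7 0.001
    # parser.parse_args(["-D", "0.5"]) would now exit: unrecognized arguments: -D 0.5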

makefile (2 lines changed)

@@ -4,4 +4,4 @@ lint:
 	flake8 --ignore E501 .
 
 test:
-	python main.py --alpha 0.7 --lr 1e-3 -D 0.5 --max_epochs 1000
+	python main.py --alpha 0.7 --lr 1e-3 --max_epochs 1000

model.py (25 lines changed)

@@ -1,10 +1,10 @@
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from losses import enhanced_loss, weighted_loss
 
-
 # class ColorTransformerModel(pl.LightningModule):
 #     def __init__(self, alpha, distinct_threshold, learning_rate):
 #         super().__init__()
@@ -22,8 +22,9 @@ from losses import enhanced_loss, weighted_loss
 #     def forward(self, x):
 #         return self.layers(x)
 
+
 class ColorTransformerModel(pl.LightningModule):
-    def __init__(self, alpha, distinct_threshold, learning_rate):
+    def __init__(self, alpha, learning_rate):
         super().__init__()
 
         self.save_hyperparameters()
@@ -34,12 +35,13 @@ class ColorTransformerModel(pl.LightningModule):
         transformer_layer = nn.TransformerEncoderLayer(
             d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
         )
-        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3)
+        self.transformer_encoder = nn.TransformerEncoder(
+            transformer_layer, num_layers=3
+        )
 
         # Final linear layer to map back to 1D space
         self.final_layer = nn.Linear(128, 1)
 
     def forward(self, x):
         # Embedding the input
         x = self.embedding(x)
@@ -56,6 +58,9 @@ class ColorTransformerModel(pl.LightningModule):
         # Final linear layer
         x = self.final_layer(x)
 
+        # Apply sigmoid activation to ensure output is in (0, 1)
+        x = torch.sigmoid(x)
+
         return x
 
     def training_step(self, batch, batch_idx):
@@ -66,11 +71,17 @@ class ColorTransformerModel(pl.LightningModule):
             inputs,
             outputs,
             alpha=self.hparams.alpha,
-            distinct_threshold=self.hparams.distinct_threshold,
         )
         self.log("train_loss", loss)
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-        return optimizer
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2)
+        lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': lr_scheduler,
+                'monitor': 'train_loss',  # Specify the metric to monitor
+            }
+        }
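
`configure_optimizers` now swaps Adam for AdamW with weight decay and returns Lightning's dictionary format so that `ReduceLROnPlateau` can monitor the logged `train_loss`; the forward pass also gains a sigmoid so outputs land in (0, 1). A rough plain-PyTorch sketch of the scheduler semantics Lightning derives from that dictionary (the stand-in linear model and placeholder loss are assumptions for illustration):

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(3, 1)  # stand-in for ColorTransformerModel
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10)

    for epoch in range(50):
        train_loss = model(torch.rand(16, 3)).abs().mean()  # placeholder loss
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        # monitor="train_loss" tells Lightning to call this with the logged
        # metric; the LR is cut 10x after 10 epochs without improvement
        scheduler.step(train_loss)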

search.py (26 lines changed)

@@ -1,6 +1,7 @@
-from lightning_sdk import Studio, Machine
 from random import sample
 
+from lightning_sdk import Machine, Studio
+
 NUM_JOBS = 4
 
 # reference to the current studio
@@ -8,31 +9,32 @@ NUM_JOBS = 4
 studio = Studio()
 
 # use the jobs plugin
-studio.install_plugin('jobs')
-job_plugin = studio.installed_plugins['jobs']
+studio.install_plugin("jobs")
+job_plugin = studio.installed_plugins["jobs"]
 
 # do a sweep over learning rates
 # Define the ranges or sets of values for each hyperparameter
 alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
-distinct_threshold_values = [0.5, 0.6, 0.7, 0.8]
-learning_rate_values = [1e-6, 1-e5, 1e-4, 1e-3, 1e-2]
+learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
 batch_size_values = [32, 64, 128]
 max_epochs_values = [10000]
 
 # Generate all possible combinations of hyperparameters
-all_params = [(alpha, dt, lr, bs, me) for alpha in alpha_values
-              for dt in distinct_threshold_values
-              for lr in learning_rate_values
-              for bs in batch_size_values
-              for me in max_epochs_values]
+all_params = [
+    (alpha, lr, bs, me)
+    for alpha in alpha_values
+    for lr in learning_rate_values
+    for bs in batch_size_values
+    for me in max_epochs_values
+]
 
 # perform random search with a limit
-random_search_params = sample(search_params, NUM_JOBS)
+search_params = sample(all_params, NUM_JOBS)
 
 # start all jobs on an A10G GPU with names containing an index
-for idx, (a, thresh, lr, bs, max_epochs) in enumerate(search_params):
-    cmd = f'python main.py --alpha {a} -D {thresh} --lr {lr} --bs {bs} --max_epochs {max_epochs}'
-    job_name = f'color-exp-{idx}'
+for idx, (a, lr, bs, me) in enumerate(search_params):
+    cmd = f"python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}"
+    job_name = f"color-exp-{idx}"
     job_plugin.run(cmd, machine=Machine.T4, name=job_name)
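
The rewritten sweep enumerates the full grid (5 alphas x 5 learning rates x 3 batch sizes x 1 epoch count = 75 combinations) and launches a random subset, which also fixes two bugs in the old version: the `1-e5` typo and the use of `search_params` before assignment. A self-contained sketch of the same sampling logic without the `lightning_sdk` submission (the seed is an assumption for reproducibility, not in the original):

    from itertools import product
    from random import sample, seed

    alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
    learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
    batch_size_values = [32, 64, 128]
    max_epochs_values = [10000]

    seed(0)  # assumption: seeded only for this illustration
    # product(...) builds the same 75 tuples as the nested comprehension above
    all_params = list(product(alpha_values, learning_rate_values,
                              batch_size_values, max_epochs_values))
    for idx, (a, lr, bs, me) in enumerate(sample(all_params, 4)):
        print(f"color-exp-{idx}: python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}")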
