
chkpt
Branch: new-sep-loss · Michael Pilosov · 11 months ago
commit d34a6b55d0
Changed files:
  1. losses.py (6)
  2. main.py (22)
  3. makefile (2)
  4. model.py (52)
  5. search.py (38)

losses.py

@@ -24,8 +24,10 @@ def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
     distinct_colors = rgb_norm > distinct_threshold

     # Penalty for Distinct Colors being too close in the transformed space
-    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)
+    # Here we do not take the mean yet, to avoid double averaging
+    distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors

     # Combined Loss
-    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
+    # The mean is taken here, once, after all components are combined
+    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
     return torch.mean(loss)
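A quick standalone sketch of what this change does (not part of the commit; random tensors stand in for rgb_norm and transformed_norm, which the real enhanced_loss computes from the batch): every component stays per-element until the end, so the mean is taken exactly once over the combined expression rather than once for the penalty and again for the total.

import torch

# Stand-ins for the per-element distance tensors used in enhanced_loss
rgb_norm = torch.rand(100)
transformed_norm = torch.rand(100)
alpha, distinct_threshold = 0.5, 0.5

# Boolean mask of elements considered distinct in RGB space
distinct_colors = rgb_norm > distinct_threshold
# Per-element penalty; no mean here, to avoid double averaging
distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors

# Convex combination plus penalty, reduced once at the end
loss = torch.mean(alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty)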

main.py

@@ -5,14 +5,13 @@ import pytorch_lightning as pl
 from dataloader import create_named_dataloader as init_data
 from model import ColorTransformerModel

-if __name__ == "__main__":
+def parse_args():
     # Define argument parser
     parser = argparse.ArgumentParser(description="Color Transformer Training Script")

     # Add arguments
     parser.add_argument(
-        "-bs",
-        "--batch_size",
+        "--bs",
         type=int,
         default=64,
         help="Input batch size for training",
@@ -21,19 +20,19 @@ if __name__ == "__main__":
         "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
     )
     parser.add_argument(
-        "-lr", "--learning_rate", type=float, default=1e-5, help="Learning rate"
+        "--lr", type=float, default=1e-5, help="Learning rate"
     )
     parser.add_argument(
         "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
     )
     parser.add_argument(
-        "-log", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
+        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
     )
     parser.add_argument(
-        "--num_workers", type=int, default=3, help="Number of workers for data loading"
+        "-w", "--num_workers", type=int, default=3, help="Number of workers for data loading"
     )
     parser.add_argument(
-        "-ds",
+        "-D",
         "--distinct_threshold",
         type=float,
         default=0.5,
@@ -42,10 +41,15 @@ if __name__ == "__main__":
     # Parse arguments
     args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()

     # Initialize data loader with parsed arguments
     train_dataloader = init_data(
-        batch_size=args.batch_size,
+        batch_size=args.bs,
         shuffle=True,
         num_workers=args.num_workers,
     )
@@ -54,7 +58,7 @@ if __name__ == "__main__":
     model = ColorTransformerModel(
         alpha=args.alpha,
         distinct_threshold=args.distinct_threshold,
-        learning_rate=args.learning_rate,
+        learning_rate=args.lr,
     )

     # Initialize trainer with parsed arguments
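One detail behind the flag renames above (a minimal sketch, not from the commit): argparse derives the attribute name from the first long option, so "--bs" with no "--batch_size" alias is read back as args.bs, which is why the call sites switch from args.batch_size to args.bs (and likewise args.learning_rate to args.lr).

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=64)      # attribute becomes args.bs
parser.add_argument("--lr", type=float, default=1e-5)  # attribute becomes args.lr
args = parser.parse_args(["--bs", "32"])
assert args.bs == 32 and args.lr == 1e-5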

makefile

@@ -4,4 +4,4 @@ lint:
 	flake8 --ignore E501 .

 test:
-	python main.py --alpha 0.7 -lr 1e-5
+	python main.py --alpha 0.7 --lr 1e-3 -D 0.5 --max_epochs 1000

model.py

@@ -5,22 +5,58 @@ import torch.nn as nn
 from losses import enhanced_loss, weighted_loss

+# class ColorTransformerModel(pl.LightningModule):
+#     def __init__(self, alpha, distinct_threshold, learning_rate):
+#         super().__init__()
+#         self.save_hyperparameters()
+#
+#         # Model layers
+#         self.layers = nn.Sequential(
+#             nn.Linear(3, 128),
+#             nn.ReLU(),
+#             nn.Linear(128, 128),
+#             nn.ReLU(),
+#             nn.Linear(128, 1),
+#         )
+#
+#     def forward(self, x):
+#         return self.layers(x)
+
+
 class ColorTransformerModel(pl.LightningModule):
     def __init__(self, alpha, distinct_threshold, learning_rate):
         super().__init__()
         self.save_hyperparameters()

-        # Model layers
-        self.layers = nn.Sequential(
-            nn.Linear(3, 128),
-            nn.ReLU(),
-            nn.Linear(128, 128),
-            nn.ReLU(),
-            nn.Linear(128, 3),
-        )
+        # Embedding layer to expand the input dimensions
+        self.embedding = nn.Linear(3, 128)
+
+        # Transformer block
+        transformer_layer = nn.TransformerEncoderLayer(
+            d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
+        )
+        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3)
+
+        # Final linear layer to map back to 1D space
+        self.final_layer = nn.Linear(128, 1)

     def forward(self, x):
-        return self.layers(x)
+        # Embedding the input
+        x = self.embedding(x)
+
+        # Adjusting the shape for the transformer
+        x = x.unsqueeze(1)  # Adding a fake sequence dimension
+
+        # Passing through the transformer
+        x = self.transformer_encoder(x)
+
+        # Reshape back to original shape
+        x = x.squeeze(1)
+
+        # Final linear layer
+        x = self.final_layer(x)
+        return x

     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
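A standalone shape check for the new forward pass (mirrors the layers above; not part of the commit). One caveat worth flagging: nn.TransformerEncoderLayer defaults to sequence-first (seq, batch, d_model) input, so after unsqueeze(1) the (batch, 1, 128) tensor is read as a length-batch sequence with batch size 1, meaning self-attention attends across samples in the batch; if that is not intended, batch_first=True would keep samples independent.

import torch
import torch.nn as nn

embedding = nn.Linear(3, 128)
layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, dim_feedforward=512, dropout=0.1)
encoder = nn.TransformerEncoder(layer, num_layers=3)
final_layer = nn.Linear(128, 1)

x = torch.rand(64, 3)                # a batch of 64 RGB colors
h = embedding(x)                     # (64, 128)
h = h.unsqueeze(1)                   # (64, 1, 128) -- fake sequence dimension
h = encoder(h)                       # (64, 1, 128)
out = final_layer(h.squeeze(1))      # (64, 1): one scalar per color
print(out.shape)                     # torch.Size([64, 1])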

search.py

@@ -0,0 +1,38 @@
+from lightning_sdk import Studio, Machine
+from random import sample
+
+NUM_JOBS = 4
+
+# reference to the current studio
+# if you run outside of Lightning, you can pass the Studio name
+studio = Studio()
+
+# use the jobs plugin
+studio.install_plugin('jobs')
+job_plugin = studio.installed_plugins['jobs']
+
+# do a sweep over learning rates
+# Define the ranges or sets of values for each hyperparameter
+alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
+distinct_threshold_values = [0.5, 0.6, 0.7, 0.8]
+learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
+batch_size_values = [32, 64, 128]
+max_epochs_values = [10000]
+
+# Generate all possible combinations of hyperparameters
+all_params = [(alpha, dt, lr, bs, me) for alpha in alpha_values
+              for dt in distinct_threshold_values
+              for lr in learning_rate_values
+              for bs in batch_size_values
+              for me in max_epochs_values]
+
+# perform random search with a limit
+random_search_params = sample(all_params, NUM_JOBS)
+
+# start the sampled jobs on a T4 GPU with names containing an index
+for idx, (a, thresh, lr, bs, max_epochs) in enumerate(random_search_params):
+    cmd = f'python main.py --alpha {a} -D {thresh} --lr {lr} --bs {bs} --max_epochs {max_epochs}'
+    job_name = f'color-exp-{idx}'
+    job_plugin.run(cmd, machine=Machine.T4, name=job_name)
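The nested comprehension above can also be spelled with itertools.product, which builds the same Cartesian grid (a self-contained sketch reusing the same value lists; sample() draws NUM_JOBS combinations without replacement, so NUM_JOBS must not exceed the grid size):

from itertools import product
from random import sample

NUM_JOBS = 4
alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
distinct_threshold_values = [0.5, 0.6, 0.7, 0.8]
learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
batch_size_values = [32, 64, 128]
max_epochs_values = [10000]

# Same Cartesian grid as the nested comprehension in search.py
all_params = list(product(alpha_values, distinct_threshold_values,
                          learning_rate_values, batch_size_values,
                          max_epochs_values))
random_search_params = sample(all_params, NUM_JOBS)
print(len(all_params))  # 5 * 4 * 5 * 3 * 1 = 300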