diff --git a/losses.py b/losses.py
index ed78468..770001d 100644
--- a/losses.py
+++ b/losses.py
@@ -24,8 +24,10 @@ def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
     distinct_colors = rgb_norm > distinct_threshold
 
     # Penalty for Distinct Colors being too close in the transformed space
-    distinct_penalty = torch.mean((1.0 / (transformed_norm + 1e-6)) * distinct_colors)
+    # Here we do not take the mean yet, to avoid double averaging
+    distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors
 
     # Combined Loss
-    loss = alpha * rgb_norm + transformed_norm + distinct_penalty
+    # The mean is taken here, once, after all components are combined
+    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty
     return torch.mean(loss)
diff --git a/main.py b/main.py
index 5aa96f8..6f18d27 100644
--- a/main.py
+++ b/main.py
@@ -5,14 +5,13 @@ import pytorch_lightning as pl
 from dataloader import create_named_dataloader as init_data
 from model import ColorTransformerModel
 
-if __name__ == "__main__":
+def parse_args():
     # Define argument parser
     parser = argparse.ArgumentParser(description="Color Transformer Training Script")
 
     # Add arguments
     parser.add_argument(
-        "-bs",
-        "--batch_size",
+        "--bs",
         type=int,
         default=64,
         help="Input batch size for training",
@@ -21,19 +20,19 @@ if __name__ == "__main__":
         "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function"
     )
     parser.add_argument(
-        "-lr", "--learning_rate", type=float, default=1e-5, help="Learning rate"
+        "--lr", type=float, default=1e-5, help="Learning rate"
     )
     parser.add_argument(
         "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train"
     )
     parser.add_argument(
-        "-log", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
+        "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency"
     )
     parser.add_argument(
-        "--num_workers", type=int, default=3, help="Number of workers for data loading"
+        "-w", "--num_workers", type=int, default=3, help="Number of workers for data loading"
     )
     parser.add_argument(
-        "-ds",
+        "-D",
         "--distinct_threshold",
         type=float,
         default=0.5,
@@ -42,10 +41,15 @@ if __name__ == "__main__":
 
     # Parse arguments
     args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+
+    args = parse_args()
 
     # Initialize data loader with parsed arguments
     train_dataloader = init_data(
-        batch_size=args.batch_size,
+        batch_size=args.bs,
         shuffle=True,
         num_workers=args.num_workers,
     )
@@ -54,7 +58,7 @@ if __name__ == "__main__":
     model = ColorTransformerModel(
         alpha=args.alpha,
         distinct_threshold=args.distinct_threshold,
-        learning_rate=args.learning_rate,
+        learning_rate=args.lr,
     )
 
     # Initialize trainer with parsed arguments
diff --git a/makefile b/makefile
index ffe6a55..009f754 100644
--- a/makefile
+++ b/makefile
@@ -4,4 +4,4 @@ lint:
 	flake8 --ignore E501 .
 
 test:
-	python main.py --alpha 0.7 -lr 1e-5
\ No newline at end of file
+	python main.py --alpha 0.7 --lr 1e-3 -D 0.5 --max_epochs 1000
\ No newline at end of file
diff --git a/model.py b/model.py
index 04c619d..168efa2 100644
--- a/model.py
+++ b/model.py
@@ -5,22 +5,58 @@ import torch.nn as nn
 from losses import enhanced_loss, weighted_loss
 
+# class ColorTransformerModel(pl.LightningModule):
+#     def __init__(self, alpha, distinct_threshold, learning_rate):
+#         super().__init__()
+#         self.save_hyperparameters()
+
+#         # Model layers
+#         self.layers = nn.Sequential(
+#             nn.Linear(3, 128),
+#             nn.ReLU(),
+#             nn.Linear(128, 128),
+#             nn.ReLU(),
+#             nn.Linear(128, 1),
+#         )
+
+#     def forward(self, x):
+#         return self.layers(x)
+
 
 class ColorTransformerModel(pl.LightningModule):
     def __init__(self, alpha, distinct_threshold, learning_rate):
         super().__init__()
         self.save_hyperparameters()
 
-        # Model layers
-        self.layers = nn.Sequential(
-            nn.Linear(3, 128),
-            nn.ReLU(),
-            nn.Linear(128, 128),
-            nn.ReLU(),
-            nn.Linear(128, 3),
+        # Embedding layer to expand the input dimensions
+        self.embedding = nn.Linear(3, 128)
+
+        # Transformer block
+        transformer_layer = nn.TransformerEncoderLayer(
+            d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
         )
+        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3)
+
+        # Final linear layer to map back to 1D space
+        self.final_layer = nn.Linear(128, 1)
+
 
     def forward(self, x):
-        return self.layers(x)
+        # Embedding the input
+        x = self.embedding(x)
+
+        # Adjusting the shape for the transformer
+        x = x.unsqueeze(1)  # Adding a fake sequence dimension
+
+        # Passing through the transformer
+        x = self.transformer_encoder(x)
+
+        # Reshape back to original shape
+        x = x.squeeze(1)
+
+        # Final linear layer
+        x = self.final_layer(x)
+
+        return x
 
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
diff --git a/search.py b/search.py
new file mode 100644
index 0000000..c51ae59
--- /dev/null
+++ b/search.py
@@ -0,0 +1,38 @@
+from lightning_sdk import Studio, Machine
+from random import sample
+
+NUM_JOBS = 4
+
+# reference to the current studio
+# if you run outside of Lightning, you can pass the Studio name
+studio = Studio()
+
+# use the jobs plugin
+studio.install_plugin('jobs')
+job_plugin = studio.installed_plugins['jobs']
+
+# do a sweep over hyperparameters
+
+# Define the ranges or sets of values for each hyperparameter
+alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
+distinct_threshold_values = [0.5, 0.6, 0.7, 0.8]
+learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
+batch_size_values = [32, 64, 128]
+max_epochs_values = [10000]
+
+# Generate all possible combinations of hyperparameters
+all_params = [(alpha, dt, lr, bs, me) for alpha in alpha_values
+              for dt in distinct_threshold_values
+              for lr in learning_rate_values
+              for bs in batch_size_values
+              for me in max_epochs_values]
+
+
+# perform random search with a limit
+random_search_params = sample(all_params, NUM_JOBS)
+
+# start the sampled jobs on a T4 GPU with names containing an index
+for idx, (a, thresh, lr, bs, max_epochs) in enumerate(random_search_params):
+    cmd = f'python main.py --alpha {a} -D {thresh} --lr {lr} --bs {bs} --max_epochs {max_epochs}'
+    job_name = f'color-exp-{idx}'
+    job_plugin.run(cmd, machine=Machine.T4, name=job_name)
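
As a quick sanity check on the new architecture, here is a minimal sketch (not part of the diff) of a forward pass through the transformer-based ColorTransformerModel added in model.py; the constructor arguments and the 1-D output follow the diff above, while the batch size and input values are purely illustrative.

    import torch

    from model import ColorTransformerModel

    # hypothetical smoke test for the transformer-based model added in the diff
    model = ColorTransformerModel(alpha=0.5, distinct_threshold=0.5, learning_rate=1e-3)
    rgb = torch.rand(8, 3)    # batch of 8 RGB colors in [0, 1]
    out = model(rgb)          # embedding -> transformer encoder -> final linear layer
    print(out.shape)          # expected: torch.Size([8, 1])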