diff --git a/losses.py b/losses.py index 770001d..98badd7 100644 --- a/losses.py +++ b/losses.py @@ -13,21 +13,45 @@ def weighted_loss(inputs, outputs, alpha): return torch.mean(loss) -def enhanced_loss(inputs, outputs, alpha, distinct_threshold): +# def enhanced_loss(inputs, outputs, alpha, distinct_threshold): +# # Calculate RGB Norm +# rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1) + +# # Calculate 1D Space Norm +# transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1) + +# # Identify Distinct Colors (based on a threshold in RGB space) +# distinct_colors = rgb_norm > distinct_threshold + +# # Penalty for Distinct Colors being too close in the transformed space +# # Here we do not take the mean yet, to avoid double averaging +# distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors + +# # Combined Loss +# # The mean is taken here, once, after all components are combined +# loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty +# return torch.mean(loss) + + +def enhanced_loss(inputs, outputs, alpha): # Calculate RGB Norm rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1) # Calculate 1D Space Norm transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1) - # Identify Distinct Colors (based on a threshold in RGB space) - distinct_colors = rgb_norm > distinct_threshold + # Distance Preservation Component + # Encourages the model to keep relative distances from the RGB space in the transformed space + distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm)) - # Penalty for Distinct Colors being too close in the transformed space - # Here we do not take the mean yet, to avoid double averaging - distinct_penalty = (1.0 / (transformed_norm + 1e-6)) * distinct_colors + # Sort outputs for smoothness calculation + sorted_outputs, _ = torch.sort(outputs, dim=0) + + # Calculate smoothness in the sorted outputs + first_derivative = torch.diff(sorted_outputs, n=1, dim=0) + second_derivative = torch.diff(first_derivative, n=1, dim=0) + smoothness_loss = torch.mean(torch.abs(second_derivative)) # Combined Loss - # The mean is taken here, once, after all components are combined - loss = alpha * rgb_norm + (1 - alpha) * transformed_norm + distinct_penalty - return torch.mean(loss) + loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss + return loss diff --git a/main.py b/main.py index 6f18d27..0864bdf 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import pytorch_lightning as pl from dataloader import create_named_dataloader as init_data from model import ColorTransformerModel + def parse_args(): # Define argument parser parser = argparse.ArgumentParser(description="Color Transformer Training Script") @@ -19,9 +20,7 @@ def parse_args(): parser.add_argument( "-a", "--alpha", type=float, default=0.5, help="Alpha value for loss function" ) - parser.add_argument( - "--lr", type=float, default=1e-5, help="Learning rate" - ) + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate") parser.add_argument( "-e", "--max_epochs", type=int, default=1000, help="Number of epochs to train" ) @@ -29,22 +28,19 @@ def parse_args(): "-L", "--log_every_n_steps", type=int, default=5, help="Logging frequency" ) parser.add_argument( - "-w", "--num_workers", type=int, default=3, help="Number of workers for data loading" - ) - parser.add_argument( - "-D", - "--distinct_threshold", - type=float, - default=0.5, - help="Threshold for color distinctness penalty", + "-w", + "--num_workers", + type=int, + default=3, + help="Number of workers for data loading", ) # Parse arguments args = parser.parse_args() return args -if __name__ == "__main__": +if __name__ == "__main__": args = parse_args() # Initialize data loader with parsed arguments @@ -57,7 +53,6 @@ if __name__ == "__main__": # Initialize model with parsed arguments model = ColorTransformerModel( alpha=args.alpha, - distinct_threshold=args.distinct_threshold, learning_rate=args.lr, ) diff --git a/makefile b/makefile index 009f754..b0825f4 100644 --- a/makefile +++ b/makefile @@ -4,4 +4,4 @@ lint: flake8 --ignore E501 . test: - python main.py --alpha 0.7 --lr 1e-3 -D 0.5 --max_epochs 1000 \ No newline at end of file + python main.py --alpha 0.7 --lr 1e-3 --max_epochs 1000 \ No newline at end of file diff --git a/model.py b/model.py index 168efa2..2944ead 100644 --- a/model.py +++ b/model.py @@ -1,10 +1,10 @@ import pytorch_lightning as pl import torch import torch.nn as nn +from torch.optim.lr_scheduler import ReduceLROnPlateau from losses import enhanced_loss, weighted_loss - # class ColorTransformerModel(pl.LightningModule): # def __init__(self, alpha, distinct_threshold, learning_rate): # super().__init__() @@ -22,8 +22,9 @@ from losses import enhanced_loss, weighted_loss # def forward(self, x): # return self.layers(x) + class ColorTransformerModel(pl.LightningModule): - def __init__(self, alpha, distinct_threshold, learning_rate): + def __init__(self, alpha, learning_rate): super().__init__() self.save_hyperparameters() @@ -34,12 +35,13 @@ class ColorTransformerModel(pl.LightningModule): transformer_layer = nn.TransformerEncoderLayer( d_model=128, nhead=4, dim_feedforward=512, dropout=0.1 ) - self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=3) + self.transformer_encoder = nn.TransformerEncoder( + transformer_layer, num_layers=3 + ) # Final linear layer to map back to 1D space self.final_layer = nn.Linear(128, 1) - def forward(self, x): # Embedding the input x = self.embedding(x) @@ -56,6 +58,9 @@ class ColorTransformerModel(pl.LightningModule): # Final linear layer x = self.final_layer(x) + # Apply sigmoid activation to ensure output is in (0, 1) + x = torch.sigmoid(x) + return x def training_step(self, batch, batch_idx): @@ -66,11 +71,17 @@ class ColorTransformerModel(pl.LightningModule): inputs, outputs, alpha=self.hparams.alpha, - distinct_threshold=self.hparams.distinct_threshold, ) self.log("train_loss", loss) return loss def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return optimizer + optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2) + lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': lr_scheduler, + 'monitor': 'train_loss', # Specify the metric to monitor + } + } \ No newline at end of file diff --git a/search.py b/search.py index c51ae59..130b5dd 100644 --- a/search.py +++ b/search.py @@ -1,6 +1,7 @@ -from lightning_sdk import Studio, Machine from random import sample +from lightning_sdk import Machine, Studio + NUM_JOBS = 4 # reference to the current studio @@ -8,31 +9,32 @@ NUM_JOBS = 4 studio = Studio() # use the jobs plugin -studio.install_plugin('jobs') -job_plugin = studio.installed_plugins['jobs'] +studio.install_plugin("jobs") +job_plugin = studio.installed_plugins["jobs"] # do a sweep over learning rates # Define the ranges or sets of values for each hyperparameter alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9] -distinct_threshold_values = [0.5, 0.6, 0.7, 0.8] -learning_rate_values = [1e-6, 1-e5, 1e-4, 1e-3, 1e-2] +learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] batch_size_values = [32, 64, 128] max_epochs_values = [10000] # Generate all possible combinations of hyperparameters -all_params = [(alpha, dt, lr, bs, me) for alpha in alpha_values - for dt in distinct_threshold_values - for lr in learning_rate_values - for bs in batch_size_values - for me in max_epochs_values] +all_params = [ + (alpha, lr, bs, me) + for alpha in alpha_values + for lr in learning_rate_values + for bs in batch_size_values + for me in max_epochs_values +] # perform random search with a limit -random_search_params = sample(search_params, NUM_JOBS) +search_params = sample(all_params, NUM_JOBS) # start all jobs on an A10G GPU with names containing an index -for idx, (a, thresh, lr, bs, max_epochs) in enumerate(search_params): - cmd = f'python main.py --alpha {a} -D {thresh} --lr {lr} --bs {bs} --max_epochs {max_epochs}' - job_name = f'color-exp-{idx}' +for idx, (a, lr, bs, me) in enumerate(search_params): + cmd = f"python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}" + job_name = f"color-exp-{idx}" job_plugin.run(cmd, machine=Machine.T4, name=job_name)