diff --git a/check.py b/check.py
new file mode 100644
index 0000000..ddf855b
--- /dev/null
+++ b/check.py
@@ -0,0 +1,25 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+from dataloader import extract_colors
+from model import ColorTransformerModel
+
+name = "color_128_0.3_1.00e-06"
+ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
+M = ColorTransformerModel.load_from_checkpoint(ckpt)
+
+rgb_tensor, names = extract_colors()
+preds = M(rgb_tensor)
+rgb_values = rgb_tensor.detach().numpy()
+sorted_inds = np.argsort(preds.detach().numpy().ravel())
+
+
+fig, ax = plt.subplots(figsize=(10, 5))
+for i in range(len(sorted_inds)):
+    idx = sorted_inds[i]
+    color = rgb_values[idx]
+    ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx])
+    ax.axis("off")
+    # ax.axis("square")
+
+plt.savefig(f"{name}.png", dpi=300)
diff --git a/color_128_0.3_1.00e-06.png b/color_128_0.3_1.00e-06.png
new file mode 100644
index 0000000..447679e
Binary files /dev/null and b/color_128_0.3_1.00e-06.png differ
diff --git a/model.py b/model.py
index 2944ead..fd7fc92 100644
--- a/model.py
+++ b/model.py
@@ -3,24 +3,7 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import enhanced_loss, weighted_loss
-
-# class ColorTransformerModel(pl.LightningModule):
-#     def __init__(self, alpha, distinct_threshold, learning_rate):
-#         super().__init__()
-#         self.save_hyperparameters()
-
-#         # Model layers
-#         self.layers = nn.Sequential(
-#             nn.Linear(3, 128),
-#             nn.ReLU(),
-#             nn.Linear(128, 128),
-#             nn.ReLU(),
-#             nn.Linear(128, 1),
-#         )
-
-#     def forward(self, x):
-#         return self.layers(x)
+from losses import enhanced_loss, weighted_loss  # noqa: F401
 
 
 class ColorTransformerModel(pl.LightningModule):
@@ -28,40 +11,59 @@ class ColorTransformerModel(pl.LightningModule):
         super().__init__()
         self.save_hyperparameters()
 
-        # Embedding layer to expand the input dimensions
-        self.embedding = nn.Linear(3, 128)
-
-        # Transformer block
-        transformer_layer = nn.TransformerEncoderLayer(
-            d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
-        )
-        self.transformer_encoder = nn.TransformerEncoder(
-            transformer_layer, num_layers=3
+        # Model layers
+        self.layers = nn.Sequential(
+            nn.Linear(3, 128),
+            nn.ReLU(),
+            nn.Linear(128, 128),
+            nn.ReLU(),
+            nn.Linear(128, 1),
         )
 
-        # Final linear layer to map back to 1D space
-        self.final_layer = nn.Linear(128, 1)
-
     def forward(self, x):
-        # Embedding the input
-        x = self.embedding(x)
+        x = self.layers(x)
+        x = torch.sigmoid(x)
+        return x
 
-        # Adjusting the shape for the transformer
-        x = x.unsqueeze(1)  # Adding a fake sequence dimension
+    # class ColorTransformerModel(pl.LightningModule):
+    #     def __init__(self, alpha, learning_rate):
+    #         super().__init__()
+    #         self.save_hyperparameters()
 
-        # Passing through the transformer
-        x = self.transformer_encoder(x)
+    #         # Embedding layer to expand the input dimensions
+    #         self.embedding = nn.Linear(3, 128)
 
-        # Reshape back to original shape
-        x = x.squeeze(1)
+    #         # Transformer block
+    #         transformer_layer = nn.TransformerEncoderLayer(
+    #             d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
+    #         )
+    #         self.transformer_encoder = nn.TransformerEncoder(
+    #             transformer_layer, num_layers=3
+    #         )
 
-        # Final linear layer
-        x = self.final_layer(x)
+    #         # Final linear layer to map back to 1D space
+    #         self.final_layer = nn.Linear(128, 1)
 
-        # Apply sigmoid activation to ensure output is in (0, 1)
-        x = torch.sigmoid(x)
+    #     def forward(self, x):
+    #         # Embedding the input
+    #         x = self.embedding(x)
 
-        return x
+    #         # Adjusting the shape for the transformer
+    #         x = x.unsqueeze(1)  # Adding a fake sequence dimension
+
+    #         # Passing through the transformer
+    #         x = self.transformer_encoder(x)
+
+    #         # Reshape back to original shape
+    #         x = x.squeeze(1)
+
+    #         # Final linear layer
+    #         x = self.final_layer(x)
+
+    #         # Apply sigmoid activation to ensure output is in (0, 1)
+    #         x = torch.sigmoid(x)
+
+    #         return x
 
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
@@ -76,12 +78,16 @@ class ColorTransformerModel(pl.LightningModule):
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2)
-        lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+        optimizer = torch.optim.AdamW(
+            self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2
+        )
+        lr_scheduler = ReduceLROnPlateau(
+            optimizer, mode="min", factor=0.1, patience=10, verbose=True
+        )
         return {
-            'optimizer': optimizer,
-            'lr_scheduler': {
-                'scheduler': lr_scheduler,
-                'monitor': 'train_loss',  # Specify the metric to monitor
-            }
-        }
\ No newline at end of file
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "monitor": "train_loss",  # Specify the metric to monitor
+            },
+        }
diff --git a/search.py b/search.py
index 130b5dd..a9e5f26 100644
--- a/search.py
+++ b/search.py
@@ -2,7 +2,7 @@ from random import sample
 
 from lightning_sdk import Machine, Studio
 
-NUM_JOBS = 4
+NUM_JOBS = 21
 
 # reference to the current studio
 # if you run outside of Lightning, you can pass the Studio name
@@ -18,7 +18,7 @@ job_plugin = studio.installed_plugins["jobs"]
 alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
 learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
 batch_size_values = [32, 64, 128]
-max_epochs_values = [10000]
+max_epochs_values = [1000]
 
 # Generate all possible combinations of hyperparameters
 all_params = [
@@ -33,8 +33,8 @@ all_params = [
 # perform random search with a limit
 search_params = sample(all_params, NUM_JOBS)
 
-# start all jobs on an A10G GPU with names containing an index
-for idx, (a, lr, bs, me) in enumerate(search_params):
-    cmd = f"python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}"
-    job_name = f"color-exp-{idx}"
+for idx, params in enumerate(search_params):
+    a, lr, bs, me = params
+    cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}"
+    job_name = f"color_{bs}_{a}_{lr:2.2e}"
     job_plugin.run(cmd, machine=Machine.T4, name=job_name)