
go back to simpler model

Branch: new-sep-loss
Author: mm, 11 months ago
Parent commit: 64e18eb6cf
1. check.py (25)
2. color_128_0.3_1.00e-06.png (BIN)
3. model.py (106)
4. search.py (12)

25
check.py

@@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
import numpy as np

from dataloader import extract_colors
from model import ColorTransformerModel

name = "color_128_0.3_1.00e-06"
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
M = ColorTransformerModel.load_from_checkpoint(ckpt)

rgb_tensor, names = extract_colors()
preds = M(rgb_tensor)
rgb_values = rgb_tensor.detach().numpy()
sorted_inds = np.argsort(preds.detach().numpy().ravel())

fig, ax = plt.subplots(figsize=(10, 5))
for i in range(len(sorted_inds)):
    idx = sorted_inds[i]
    color = rgb_values[idx]
    ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx])
ax.axis("off")
# ax.axis("square")
plt.savefig(f"{name}.png", dpi=300)
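
Note: check.py imports dataloader.extract_colors, which is not part of this commit. A minimal stand-in, assuming it returns an (N, 3) float tensor of RGB values in [0, 1] plus a parallel list of matplotlib-recognizable color names (the names are passed straight to ax.vlines as colors=), might look like the sketch below; the real dataloader may differ.

# Hypothetical stand-in for dataloader.extract_colors() -- not the project's
# actual implementation, only the contract check.py appears to rely on.
import torch
import matplotlib.colors as mcolors

def extract_colors():
    names = list(mcolors.CSS4_COLORS)          # e.g. "aliceblue", "red", ...
    rgb = [mcolors.to_rgb(n) for n in names]   # name -> (r, g, b) floats in [0, 1]
    return torch.tensor(rgb, dtype=torch.float32), names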

BIN
color_128_0.3_1.00e-06.png

Binary file not shown (28 KiB).

106
model.py

@@ -3,24 +3,7 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import enhanced_loss, weighted_loss
+from losses import enhanced_loss, weighted_loss  # noqa: F401
 
-# class ColorTransformerModel(pl.LightningModule):
-#     def __init__(self, alpha, distinct_threshold, learning_rate):
-#         super().__init__()
-#         self.save_hyperparameters()
-#         # Model layers
-#         self.layers = nn.Sequential(
-#             nn.Linear(3, 128),
-#             nn.ReLU(),
-#             nn.Linear(128, 128),
-#             nn.ReLU(),
-#             nn.Linear(128, 1),
-#         )
-#     def forward(self, x):
-#         return self.layers(x)
 
 class ColorTransformerModel(pl.LightningModule):

@@ -28,40 +11,59 @@ class ColorTransformerModel(pl.LightningModule):
         super().__init__()
         self.save_hyperparameters()
 
-        # Embedding layer to expand the input dimensions
-        self.embedding = nn.Linear(3, 128)
-
-        # Transformer block
-        transformer_layer = nn.TransformerEncoderLayer(
-            d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
-        )
-        self.transformer_encoder = nn.TransformerEncoder(
-            transformer_layer, num_layers=3
+        # Model layers
+        self.layers = nn.Sequential(
+            nn.Linear(3, 128),
+            nn.ReLU(),
+            nn.Linear(128, 128),
+            nn.ReLU(),
+            nn.Linear(128, 1),
         )
 
-        # Final linear layer to map back to 1D space
-        self.final_layer = nn.Linear(128, 1)
-
     def forward(self, x):
-        # Embedding the input
-        x = self.embedding(x)
-
-        # Adjusting the shape for the transformer
-        x = x.unsqueeze(1)  # Adding a fake sequence dimension
-
-        # Passing through the transformer
-        x = self.transformer_encoder(x)
-
-        # Reshape back to original shape
-        x = x.squeeze(1)
-
-        # Final linear layer
-        x = self.final_layer(x)
-
-        # Apply sigmoid activation to ensure output is in (0, 1)
-        x = torch.sigmoid(x)
-
-        return x
+        x = self.layers(x)
+        x = torch.sigmoid(x)
+        return x
 
+    # class ColorTransformerModel(pl.LightningModule):
+    #     def __init__(self, alpha, learning_rate):
+    #         super().__init__()
+    #         self.save_hyperparameters()
+    #         # Embedding layer to expand the input dimensions
+    #         self.embedding = nn.Linear(3, 128)
+    #         # Transformer block
+    #         transformer_layer = nn.TransformerEncoderLayer(
+    #             d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
+    #         )
+    #         self.transformer_encoder = nn.TransformerEncoder(
+    #             transformer_layer, num_layers=3
+    #         )
+    #         # Final linear layer to map back to 1D space
+    #         self.final_layer = nn.Linear(128, 1)
+    #     def forward(self, x):
+    #         # Embedding the input
+    #         x = self.embedding(x)
+    #         # Adjusting the shape for the transformer
+    #         x = x.unsqueeze(1)  # Adding a fake sequence dimension
+    #         # Passing through the transformer
+    #         x = self.transformer_encoder(x)
+    #         # Reshape back to original shape
+    #         x = x.squeeze(1)
+    #         # Final linear layer
+    #         x = self.final_layer(x)
+    #         # Apply sigmoid activation to ensure output is in (0, 1)
+    #         x = torch.sigmoid(x)
+    #         return x
 
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings

@@ -76,12 +78,16 @@ class ColorTransformerModel(pl.LightningModule):
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2)
-        lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+        optimizer = torch.optim.AdamW(
+            self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2
+        )
+        lr_scheduler = ReduceLROnPlateau(
+            optimizer, mode="min", factor=0.1, patience=10, verbose=True
+        )
         return {
-            'optimizer': optimizer,
-            'lr_scheduler': {
-                'scheduler': lr_scheduler,
-                'monitor': 'train_loss',  # Specify the metric to monitor
-            }
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "monitor": "train_loss",  # Specify the metric to monitor
+            },
         }
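
A quick sanity check of the restored MLP forward pass, as a sketch only. It assumes the constructor takes alpha and learning_rate, matching the commented-out variant above; the actual __init__ signature is outside this hunk.

# Sketch: verify output shape and sigmoid range of the simplified model.
# Assumed constructor arguments (alpha, learning_rate) are not shown in this diff.
import torch
from model import ColorTransformerModel

model = ColorTransformerModel(alpha=0.3, learning_rate=1e-6)
x = torch.rand(16, 3)                     # a batch of 16 RGB triplets in [0, 1]
y = model(x)
assert y.shape == (16, 1)                 # Linear(3 -> 128 -> 128 -> 1) keeps the batch dim
assert (0 < y).all() and (y < 1).all()    # final sigmoid keeps outputs in (0, 1)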

12
search.py

@@ -2,7 +2,7 @@ from random import sample
 from lightning_sdk import Machine, Studio
 
-NUM_JOBS = 4
+NUM_JOBS = 21
 
 # reference to the current studio
 # if you run outside of Lightning, you can pass the Studio name

@@ -18,7 +18,7 @@ job_plugin = studio.installed_plugins["jobs"]
 alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9]
 learning_rate_values = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
 batch_size_values = [32, 64, 128]
-max_epochs_values = [10000]
+max_epochs_values = [1000]
 
 # Generate all possible combinations of hyperparameters
 all_params = [

@@ -33,8 +33,8 @@ all_params = [
 # perform random search with a limit
 search_params = sample(all_params, NUM_JOBS)
 
-# start all jobs on an A10G GPU with names containing an index
-for idx, (a, lr, bs, me) in enumerate(search_params):
-    cmd = f"python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}"
-    job_name = f"color-exp-{idx}"
+for idx, params in enumerate(search_params):
+    a, lr, bs, me = params
+    cmd = f"cd ~/colors && python main.py --alpha {a} --lr {lr} --bs {bs} --max_epochs {me}"
+    job_name = f"color_{bs}_{a}_{lr:2.2e}"
     job_plugin.run(cmd, machine=Machine.T4, name=job_name)
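
Worth noting: the new job_name format is what check.py keys on. With the combination bs=128, a=0.3, lr=1e-6 it reproduces the string hard-coded there as name:

# Reproduces the job name that check.py uses as its checkpoint/figure name.
bs, a, lr = 128, 0.3, 1e-6
print(f"color_{bs}_{a}_{lr:2.2e}")   # -> color_128_0.3_1.00e-06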
