Browse Source

diff arch, optimizer

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
4d8bcc8715
  1. 6
      check.py
  2. 2
      makefile
  3. 149
      model.py
  4. 7
      utils.py

6
check.py

@ -86,7 +86,11 @@ if __name__ == "__main__":
name = f"out/v{v}" name = f"out/v{v}"
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt = glob.glob(ckpt_path)[-1] ckpt = glob.glob(ckpt_path)
if len(ckpt) > 0:
ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}") print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name) create_circle(ckpt, fname=name)
else:
print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False) # make_image(ckpt, fname=name + "b", color=False)

2
makefile

@ -4,7 +4,7 @@ lint:
flake8 --ignore E501,W503 . flake8 --ignore E501,W503 .
test: test:
python main.py --alpha 2 --lr 2e-4 --max_epochs 200 python main.py --alpha 4 --lr 2e-4 --max_epochs 200
search: search:
python search.py python search.py

149
model.py

@ -5,70 +5,106 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss from losses import calculate_separation_loss, preservation_loss
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Model layers
# self.layers = nn.Sequential(
# nn.Linear(5, 128, bias=False),
# nn.Linear(128, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 64, bias=False),
# nn.Linear(64, 128, bias=False),
# nn.Linear(128, 256, bias=False),
# nn.Linear(256, 128, bias=False),
# nn.ReLU(),
# nn.Linear(128, 1, bias=False),
# )
# def forward(self, x):
# x = self.layers(x)
# x = (torch.sin(x) + 1) / 2
# return x
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Embedding layer to expand the input dimensions
# self.embedding = nn.Linear(3, 128, bias=False)
# # Transformer encoder-decoder
# encoder = nn.TransformerEncoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_encoder = nn.TransformerEncoder(
# encoder, num_layers=3
# )
# # lower dimensionality decoder
# decoder = nn.TransformerDecoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_decoder = nn.TransformerDecoder(
# decoder, num_layers=3
# )
# # Final linear layer to map back to 1D space
# self.final_layer = nn.Linear(128, 1, bias=False)
# def forward(self, x):
# # Embedding the input
# x = self.embedding(x)
# # Adjusting the shape for the transformer
# x = x.unsqueeze(1) # Adding a fake sequence dimension
# # Passing through the transformer
# x = self.transformer_encoder(x)
# # Passing through the decoder
# x = self.transformer_decoder(x, memory=x)
# # Reshape back to original shape
# x = x.squeeze(1)
# # Final linear layer
# x = self.final_layer(x)
# # Apply sigmoid activation to ensure output is in (0, 1)
# # x = torch.sigmoid(x)
# x = (torch.sin(x) + 1) / 2
# return x
class ColorTransformerModel(pl.LightningModule): class ColorTransformerModel(pl.LightningModule):
def __init__(self, params): def __init__(self, params):
super().__init__() super().__init__()
self.save_hyperparameters(params) self.save_hyperparameters(params)
# Model layers # Neural network layers
self.layers = nn.Sequential( self.network = nn.Sequential(
nn.Linear(5, 128, bias=False), nn.Linear(5, 64), # Input layer
nn.Linear(128, 3, bias=False), nn.Tanh(),
nn.ReLU(), nn.Linear(64, 128),
nn.Linear(3, 64, bias=False), nn.Tanh(),
nn.Linear(64, 128, bias=False), nn.Linear(128, 128),
nn.Linear(128, 256, bias=False), nn.Tanh(),
nn.Linear(256, 128, bias=False), nn.Linear(128, 64),
nn.ReLU(), nn.Tanh(),
nn.Linear(128, 1, bias=False), nn.Linear(64, 1), # Output layer
) )
def forward(self, x): def forward(self, x):
x = self.layers(x) # Pass the input through the network
x = (torch.sin(x) + 1) / 2 x = self.network(x)
# Circular mapping
# x = (torch.sin(x) + 1) / 2
x = torch.sigmoid(x)
return x return x
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, alpha, learning_rate):
# super().__init__()
# self.save_hyperparameters()
# # Embedding layer to expand the input dimensions
# self.embedding = nn.Linear(3, 128)
# # Transformer block
# transformer_layer = nn.TransformerEncoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.1
# )
# self.transformer_encoder = nn.TransformerEncoder(
# transformer_layer, num_layers=3
# )
# # Final linear layer to map back to 1D space
# self.final_layer = nn.Linear(128, 1)
# def forward(self, x):
# # Embedding the input
# x = self.embedding(x)
# # Adjusting the shape for the transformer
# x = x.unsqueeze(1) # Adding a fake sequence dimension
# # Passing through the transformer
# x = self.transformer_encoder(x)
# # Reshape back to original shape
# x = x.squeeze(1)
# # Final linear layer
# x = self.final_layer(x)
# # Apply sigmoid activation to ensure output is in (0, 1)
# x = torch.sigmoid(x)
# return x
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs) outputs = self.forward(inputs)
@ -84,11 +120,12 @@ class ColorTransformerModel(pl.LightningModule):
return loss return loss
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.AdamW( optimizer = torch.optim.SGD(
self.parameters(), lr=self.hparams.learning_rate, self.parameters(),
lr=self.hparams.learning_rate,
) )
lr_scheduler = ReduceLROnPlateau( lr_scheduler = ReduceLROnPlateau(
optimizer, mode="min", factor=0.1, patience=10, verbose=True optimizer, mode="min", factor=0.05, patience=10, cooldown=20, verbose=True
) )
return { return {
"optimizer": optimizer, "optimizer": optimizer,

7
utils.py

@ -1,9 +1,9 @@
import torch import torch
def preprocess_data(data): def preprocess_data(data, skip=False):
# Assuming 'data' is a tensor of shape [n_samples, 3] # Assuming 'data' is a tensor of shape [n_samples, 3]
if not skip:
# Compute argmin and argmax for each row # Compute argmin and argmax for each row
argmin_values = torch.argmin(data, dim=1, keepdim=True).float() argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
argmax_values = torch.argmax(data, dim=1, keepdim=True).float() argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
@ -15,7 +15,8 @@ def preprocess_data(data):
# Concatenate the argmin and argmax values to the original data # Concatenate the argmin and argmax values to the original data
new_data = torch.cat((data, argmin_values, argmax_values), dim=1) new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
else:
new_data = data
return new_data return new_data

Loading…
Cancel
Save