Browse Source

total overhaul of model

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
88c8cde9f6
  1. 6
      dataloader.py
  2. 77
      datamodule.py
  3. 5
      main.py
  4. 132
      model.py
  5. 28
      newmain.py

6
dataloader.py

@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, TensorDataset
from utils import extract_colors, preprocess_data
def create_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
def create_random_dataloader(N: int = 1e8, skip: bool = True, **kwargs):
rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32)
rgb_tensor = preprocess_data(rgb_tensor, skip=skip)
# Creating a dataset and data loader
@ -36,7 +36,9 @@ def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs):
if __name__ == "__main__":
batch_size = 4
train_dataloader = create_dataloader(N=1e6, batch_size=batch_size, shuffle=True)
train_dataloader = create_random_dataloader(
N=1e6, batch_size=batch_size, shuffle=True
)
print(len(train_dataloader.dataset))
train_dataloader_with_names = create_named_dataloader(
batch_size=batch_size, shuffle=True

77
datamodule.py

@ -0,0 +1,77 @@
import lightning as L
import torch
from matplotlib.colors import rgb_to_hsv
from torch.utils.data import DataLoader
from utils import extract_colors, preprocess_data
class ColorDataModule(L.LightningDataModule):
    """Lightning data module pairing RGB colors with names or hue targets.

    Two kinds of samples are produced in ``setup``:

    * ``xkcd_colors``: ``(rgb, name)`` pairs built from ``extract_colors()``,
      with the ``"xkcd:"`` prefix removed from each name.
    * random colors: ``(rgb, hue)`` pairs where ``rgb`` is drawn with
      ``torch.rand`` and ``hue`` is channel 0 of ``rgb_to_hsv``.

    Args:
        val_size: Number of random colors generated for validation/testing.
        train_size: Number of random colors generated for training. When it
            is ``0`` (the default) the XKCD colors are used for training
            instead.
        batch_size: Batch size passed to every ``DataLoader``.
    """

    def __init__(
        self, val_size: int = 10_000, train_size: int = 0, batch_size: int = 32
    ):
        super().__init__()
        self.val_size = val_size
        self.batch_size = batch_size
        self.train_size = train_size

    def prepare_data(self):
        # no state. called from main process.
        pass

    @staticmethod
    def _random_hue_pairs(n):
        """Return ``n`` random ``(rgb, hue)`` sample pairs as a list."""
        rgb = torch.rand((int(n), 3), dtype=torch.float32)
        rgb = preprocess_data(rgb, skip=True)
        # Convert the whole batch at once and take the hue channel along the
        # last axis. Indexing a single converted row with ``[:, 0]`` fails for
        # a 1-D RGB triple.
        # NOTE(review): assumes preprocess_data(skip=True) keeps a trailing
        # RGB axis of size 3 -- confirm against utils.preprocess_data.
        hue = torch.tensor(rgb_to_hsv(rgb)[..., 0], dtype=torch.float32)
        return list(zip(rgb, hue))

    def setup(self, stage: str):
        """Build the datasets and expose the ones needed for ``stage``.

        Args:
            stage: One of Lightning's stage names (``"fit"``, ``"validate"``,
                ``"test"``, ``"predict"``).
        """
        rgb_tensor, xkcd_color_names = extract_colors()
        rgb_tensor = preprocess_data(rgb_tensor, skip=True)
        self.xkcd_colors = [
            (rgb, name.replace("xkcd:", ""))
            for rgb, name in zip(rgb_tensor, xkcd_color_names)
        ]

        if self.train_size > 0:
            # Size the training set by train_size (previously val_size was
            # used here, so train_size only acted as an on/off switch).
            self.train_colors = self._random_hue_pairs(self.train_size)
        self.random_colors = self._random_hue_pairs(self.val_size)

        # Assign train/val datasets for use in dataloaders
        if stage in ("fit", "validate"):
            # "validate" is included so trainer.validate() finds color_val.
            self.color_val = self.random_colors
        if stage == "fit":
            if self.train_size > 0:
                self.color_train = self.train_colors
            else:
                self.color_train = self.xkcd_colors

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.color_test = self.random_colors

        if stage == "predict":  # for visualizing
            self.color_predict = self.xkcd_colors

    def train_dataloader(self):
        return DataLoader(self.color_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.color_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.color_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.color_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        pass

5
main.py

@ -79,7 +79,7 @@ if __name__ == "__main__":
# named_data_loader also has grayscale extras. TODO: remove unnamed
train_dataloader = create_dataloader(
# N=1e5,
skip=False,
skip=True,
batch_size=args.bs,
shuffle=True,
num_workers=args.num_workers,
@ -90,6 +90,9 @@ if __name__ == "__main__":
learning_rate=args.lr,
batch_size=args.bs,
width=args.width,
bias=False,
transform="relu",
depth=1,
)
# Initialize model with parsed arguments

132
model.py

@ -1,4 +1,4 @@
import pytorch_lightning as pl
import lightning as L
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
@ -6,118 +6,36 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss # noqa: F401
from utils import PURE_HSV, PURE_RGB
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Model layers
# self.layers = nn.Sequential(
# nn.Linear(5, 128, bias=False),
# nn.Linear(128, 3, bias=False),
# nn.ReLU(),
# nn.Linear(3, 64, bias=False),
# nn.Linear(64, 128, bias=False),
# nn.Linear(128, 256, bias=False),
# nn.Linear(256, 128, bias=False),
# nn.ReLU(),
# nn.Linear(128, 1, bias=False),
# )
# def forward(self, x):
# x = self.layers(x)
# x = (torch.sin(x) + 1) / 2
# return x
# class ColorTransformerModel(pl.LightningModule):
# def __init__(self, params):
# super().__init__()
# self.save_hyperparameters(params)
# # Embedding layer to expand the input dimensions
# self.embedding = nn.Linear(3, 128, bias=False)
# # Transformer encoder-decoder
# encoder = nn.TransformerEncoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_encoder = nn.TransformerEncoder(
# encoder, num_layers=3
# )
# # lower dimensionality decoder
# decoder = nn.TransformerDecoderLayer(
# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3
# )
# self.transformer_decoder = nn.TransformerDecoder(
# decoder, num_layers=3
# )
# # Final linear layer to map back to 1D space
# self.final_layer = nn.Linear(128, 1, bias=False)
# def forward(self, x):
# # Embedding the input
# x = self.embedding(x)
# # Adjusting the shape for the transformer
# x = x.unsqueeze(1) # Adding a fake sequence dimension
# # Passing through the transformer
# x = self.transformer_encoder(x)
# # Passing through the decoder
# x = self.transformer_decoder(x, memory=x)
# # Reshape back to original shape
# x = x.squeeze(1)
# # Final linear layer
# x = self.final_layer(x)
# # Apply sigmoid activation to ensure output is in (0, 1)
# # x = torch.sigmoid(x)
# x = (torch.sin(x) + 1) / 2
# return x
class ColorTransformerModel(pl.LightningModule):
def __init__(self, params):
class ColorTransformerModel(L.LightningModule):
def __init__(
    self,
    transform: str = "relu",
    width: int = 128,
    depth: int = 1,
    bias: bool = False,
):
    """Build the fully connected 3 -> width -> ... -> 3 -> 1 network.

    Args:
        transform: Activation name, ``"relu"`` or ``"tanh"``
            (case-insensitive).
        width: Number of units in each hidden layer.
        depth: Number of ``width -> width`` hidden blocks inserted between
            the input and output projections.
        bias: Whether the linear layers carry a bias term.

    Raises:
        ValueError: If ``transform`` is not a supported activation name.
    """
    super().__init__()
    self.save_hyperparameters()

    activations = {"tanh": nn.Tanh, "relu": nn.ReLU}
    name = self.hparams.transform.lower()
    if name not in activations:
        # Fail with a clear message instead of an unbound local later on.
        raise ValueError(
            f"Unsupported transform {self.hparams.transform!r}; "
            f"expected one of {sorted(activations)}"
        )
    t = activations[name]
    w = self.hparams.width
    d = self.hparams.depth
    bias = self.hparams.bias

    # Create a fresh Linear/activation pair per hidden block. Repeating a
    # list with ``* d`` would reuse the same module objects, so all hidden
    # blocks would share a single weight matrix.
    midlayers = []
    for _ in range(d):
        midlayers.extend([nn.Linear(w, w, bias=bias), t()])

    self.network = nn.Sequential(
        nn.Linear(3, w, bias=bias),
        t(),
        *midlayers,
        nn.Linear(w, 3, bias=bias),
        t(),
        nn.Linear(3, 1, bias=bias),
    )
def forward(self, x):
# Pass the input through the network
# a = self.a(x)
# b = self.b(x)
# a = torch.sigmoid(a)
# b = torch.sigmoid(b)
# x = torch.cat([x, a, b], dim=-1)
x = self.network(x)
# Circular mapping
# x = (torch.sin(x) + 1) / 2

28
newmain.py

@ -0,0 +1,28 @@
from lightning.pytorch.cli import LightningCLI
# from callbacks import SaveImageCallback
from datamodule import ColorDataModule
from model import ColorTransformerModel
def cli_main():
    """Create the LightningCLI for the color model and its data module.

    The CLI object is built only for its side effects, so it is not bound
    to a name.
    """
    # note: don't call fit!!
    LightningCLI(ColorTransformerModel, ColorDataModule)


if __name__ == "__main__":
    cli_main()
# note: it is good practice to implement the CLI in a function and call it in the main if block
# save_img_callback = SaveImageCallback(
# save_interval=0,
# final_dir="out",
# )
# trainer = pl.Trainer(
# callbacks=[save_img_callback],
# )
# data = ColorDataModule()
# trainer.fit(model, data)
Loading…
Cancel
Save