diff --git a/dataloader.py b/dataloader.py index 63120ac..75a779b 100644 --- a/dataloader.py +++ b/dataloader.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, TensorDataset from utils import extract_colors, preprocess_data -def create_dataloader(N: int = 1e8, skip: bool = True, **kwargs): +def create_random_dataloader(N: int = 1e8, skip: bool = True, **kwargs): rgb_tensor = torch.rand((int(N), 3), dtype=torch.float32) rgb_tensor = preprocess_data(rgb_tensor, skip=skip) # Creating a dataset and data loader @@ -36,7 +36,9 @@ def create_named_dataloader(N: int = 0, skip: bool = True, **kwargs): if __name__ == "__main__": batch_size = 4 - train_dataloader = create_dataloader(N=1e6, batch_size=batch_size, shuffle=True) + train_dataloader = create_random_dataloader( + N=1e6, batch_size=batch_size, shuffle=True + ) print(len(train_dataloader.dataset)) train_dataloader_with_names = create_named_dataloader( batch_size=batch_size, shuffle=True diff --git a/datamodule.py b/datamodule.py new file mode 100644 index 0000000..5d9b3ba --- /dev/null +++ b/datamodule.py @@ -0,0 +1,77 @@ +import lightning as L +import torch +from matplotlib.colors import rgb_to_hsv +from torch.utils.data import DataLoader + +from utils import extract_colors, preprocess_data + + +class ColorDataModule(L.LightningDataModule): + def __init__(self, val_size: int = 10_000, train_size=0, batch_size: int = 32): + super().__init__() + self.val_size = val_size + self.batch_size = batch_size + self.train_size = train_size + + def prepare_data(self): + # no state. called from main process. + pass + + def setup(self, stage: str): + rgb_tensor, xkcd_color_names = extract_colors() + rgb_tensor = preprocess_data(rgb_tensor, skip=True) + self.xkcd_colors = [ + (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) + for i in range(len(rgb_tensor)) + ] + if self.train_size > 0: + train_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) + train_rgb = preprocess_data(train_rgb, skip=True) + self.train_colors = [ + ( + train_rgb[i], + torch.tensor(rgb_to_hsv(train_rgb[i])[:, 0], dtype=torch.float32), + ) + for i in range(len(train_rgb)) + ] + + val_rgb = torch.rand((int(self.val_size), 3), dtype=torch.float32) + val_rgb = preprocess_data(val_rgb, skip=True) + self.random_colors = [ + ( + val_rgb[i], + torch.tensor(rgb_to_hsv(val_rgb[i])[:, 0], dtype=torch.float32), + ) + for i in range(len(val_rgb)) + ] + + # Assign train/val datasets for use in dataloaders + if stage == "fit": + self.color_val = self.random_colors + if self.train_size > 0: + self.color_train = self.train_colors + else: + self.color_train = self.xkcd_colors + + # Assign test dataset for use in dataloader(s) + if stage == "test": + self.color_test = self.random_colors + + if stage == "predict": # for visualizing + self.color_predict = self.xkcd_colors + + def train_dataloader(self): + return DataLoader(self.color_train, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.color_val, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.color_test, batch_size=self.batch_size) + + def predict_dataloader(self): + return DataLoader(self.color_predict, batch_size=self.batch_size) + + def teardown(self, stage: str): + # Used to clean-up when the run is finished + pass diff --git a/main.py b/main.py index 2fe83f9..aa71a40 100644 --- a/main.py +++ b/main.py @@ -79,7 +79,7 @@ if __name__ == "__main__": # named_data_loader also has grayscale extras. TODO: remove unnamed train_dataloader = create_dataloader( # N=1e5, - skip=False, + skip=True, batch_size=args.bs, shuffle=True, num_workers=args.num_workers, @@ -90,6 +90,9 @@ if __name__ == "__main__": learning_rate=args.lr, batch_size=args.bs, width=args.width, + bias=False, + transform="relu", + depth=1, ) # Initialize model with parsed arguments diff --git a/model.py b/model.py index e24415f..67fd130 100644 --- a/model.py +++ b/model.py @@ -1,4 +1,4 @@ -import pytorch_lightning as pl +import lightning as L import torch import torch.nn as nn from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -6,118 +6,36 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from losses import calculate_separation_loss, preservation_loss # noqa: F401 from utils import PURE_HSV, PURE_RGB -# class ColorTransformerModel(pl.LightningModule): -# def __init__(self, params): -# super().__init__() -# self.save_hyperparameters(params) -# # Model layers -# self.layers = nn.Sequential( -# nn.Linear(5, 128, bias=False), -# nn.Linear(128, 3, bias=False), -# nn.ReLU(), -# nn.Linear(3, 64, bias=False), -# nn.Linear(64, 128, bias=False), -# nn.Linear(128, 256, bias=False), -# nn.Linear(256, 128, bias=False), -# nn.ReLU(), -# nn.Linear(128, 1, bias=False), -# ) - -# def forward(self, x): -# x = self.layers(x) -# x = (torch.sin(x) + 1) / 2 -# return x - -# class ColorTransformerModel(pl.LightningModule): -# def __init__(self, params): -# super().__init__() -# self.save_hyperparameters(params) - -# # Embedding layer to expand the input dimensions -# self.embedding = nn.Linear(3, 128, bias=False) - -# # Transformer encoder-decoder -# encoder = nn.TransformerEncoderLayer( -# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3 -# ) -# self.transformer_encoder = nn.TransformerEncoder( -# encoder, num_layers=3 -# ) -# # lower dimensionality decoder -# decoder = nn.TransformerDecoderLayer( -# d_model=128, nhead=4, dim_feedforward=512, dropout=0.3 -# ) -# self.transformer_decoder = nn.TransformerDecoder( -# decoder, num_layers=3 -# ) - -# # Final linear layer to map back to 1D space -# self.final_layer = nn.Linear(128, 1, bias=False) - -# def forward(self, x): -# # Embedding the input -# x = self.embedding(x) - -# # Adjusting the shape for the transformer -# x = x.unsqueeze(1) # Adding a fake sequence dimension - -# # Passing through the transformer -# x = self.transformer_encoder(x) - -# # Passing through the decoder -# x = self.transformer_decoder(x, memory=x) - -# # Reshape back to original shape -# x = x.squeeze(1) - -# # Final linear layer -# x = self.final_layer(x) - -# # Apply sigmoid activation to ensure output is in (0, 1) -# # x = torch.sigmoid(x) -# x = (torch.sin(x) + 1) / 2 -# return x - - -class ColorTransformerModel(pl.LightningModule): - def __init__(self, params): +class ColorTransformerModel(L.LightningModule): + def __init__( + self, + transform: str = "relu", + width: int = 128, + depth: int = 1, + bias: bool = False, + ): super().__init__() - self.save_hyperparameters(params) - # self.a = nn.Sequential( - # nn.Linear(3, 3, bias=False), - # nn.ReLU(), - # nn.Linear(3, 3, bias=False), - # nn.ReLU(), - # nn.Linear(3, 1, bias=False), - # nn.ReLU(), - # ) - # self.b = nn.Sequential( - # nn.Linear(3, 3, bias=False), - # nn.ReLU(), - # nn.Linear(3, 3, bias=False), - # nn.ReLU(), - # nn.Linear(3, 1, bias=False), - # nn.ReLU(), - # ) - # Neural network layers + self.save_hyperparameters() + if self.hparams.transform.lower() == "tanh": + t = nn.Tanh + elif self.hparams.transform.lower() == "relu": + t = nn.ReLU + + w = self.hparams.width + d = self.hparams.depth + bias = self.hparams.bias + midlayers = [nn.Linear(w, w, bias=bias), t()] * d self.network = nn.Sequential( - nn.Linear(5, 64), - nn.Tanh(), - nn.Linear(64, self.hparams.width), - nn.Tanh(), - nn.Linear(self.hparams.width, 3), - nn.Tanh(), - nn.Linear(3, 1), + nn.Linear(3, w, bias=bias), + t(), + *midlayers, + nn.Linear(w, 3, bias=bias), + t(), + nn.Linear(3, 1, bias=bias), ) def forward(self, x): - # Pass the input through the network - # a = self.a(x) - # b = self.b(x) - # a = torch.sigmoid(a) - # b = torch.sigmoid(b) - # x = torch.cat([x, a, b], dim=-1) x = self.network(x) # Circular mapping # x = (torch.sin(x) + 1) / 2 diff --git a/newmain.py b/newmain.py new file mode 100644 index 0000000..eec05a7 --- /dev/null +++ b/newmain.py @@ -0,0 +1,28 @@ +from lightning.pytorch.cli import LightningCLI + +# from callbacks import SaveImageCallback +from datamodule import ColorDataModule +from model import ColorTransformerModel + + +def cli_main(): + cli = LightningCLI(ColorTransformerModel, ColorDataModule) # noqa: F841 + # note: don't call fit!! + + +if __name__ == "__main__": + cli_main() + # note: it is good practice to implement the CLI in a function and call it in the main if block + + +# save_img_callback = SaveImageCallback( +# save_interval=0, +# final_dir="out", +# ) + +# trainer = pl.Trainer( +# callbacks=[save_img_callback], +# ) + +# data = ColorDataModule() +# trainer.fit(model, data)