diff --git a/model.py b/model.py index e459869..b07bc25 100644 --- a/model.py +++ b/model.py @@ -14,8 +14,12 @@ class ColorTransformerModel(pl.LightningModule): # Model layers self.layers = nn.Sequential( nn.Linear(5, 128), + nn.Linear(128, 3), nn.ReLU(), - nn.Linear(128, 128), + nn.Linear(3, 64), + nn.Linear(64, 128), + nn.Linear(128, 256), + nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 1), )