@ -14,8 +14,12 @@ class ColorTransformerModel(pl.LightningModule):
# Model layers
self.layers = nn.Sequential(
nn.Linear(5, 128),
nn.Linear(128, 3),
nn.ReLU(),
nn.Linear(128, 128),
nn.Linear(3, 64),
nn.Linear(64, 128),
nn.Linear(128, 256),
nn.Linear(256, 128),
nn.Linear(128, 1),
)