@@ -13,15 +13,15 @@ class ColorTransformerModel(pl.LightningModule):
 
         # Model layers
         self.layers = nn.Sequential(
-            nn.Linear(5, 128),
-            nn.Linear(128, 3),
+            nn.Linear(5, 128, bias=False),
+            nn.Linear(128, 3, bias=False),
             nn.ReLU(),
-            nn.Linear(3, 64),
-            nn.Linear(64, 128),
-            nn.Linear(128, 256),
-            nn.Linear(256, 128),
+            nn.Linear(3, 64, bias=False),
+            nn.Linear(64, 128, bias=False),
+            nn.Linear(128, 256, bias=False),
+            nn.Linear(256, 128, bias=False),
             nn.ReLU(),
-            nn.Linear(128, 1),
+            nn.Linear(128, 1, bias=False),
         )
 
     def forward(self, x):
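As a side note, here is a minimal sketch (separate from the patch, with illustrative standalone layers) of what `bias=False` changes on an `nn.Linear`: it drops the additive bias vector, so each layer loses `out_features` parameters and its `bias` attribute becomes `None`. The shapes mirror the 128-to-256 layer above.

import torch.nn as nn

with_bias = nn.Linear(128, 256)            # weight (256x128) plus bias (256)
no_bias = nn.Linear(128, 256, bias=False)  # weight only

print(sum(p.numel() for p in with_bias.parameters()))  # 33024
print(sum(p.numel() for p in no_bias.parameters()))    # 32768
print(no_bias.bias)                                    # None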
@@ -85,7 +85,7 @@ class ColorTransformerModel(pl.LightningModule):
 
     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
-            self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2
+            self.parameters(), lr=self.hparams.learning_rate,
         )
         lr_scheduler = ReduceLROnPlateau(
             optimizer, mode="min", factor=0.1, patience=10, verbose=True
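One caveat worth flagging: `torch.optim.AdamW` already defaults to `weight_decay=0.01`, so removing the explicit `weight_decay=1e-2` argument should leave the effective decay unchanged. A quick sketch to verify the default, using a throwaway parameter:

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
print(opt.param_groups[0]["weight_decay"])  # 0.01 (AdamW's built-in default)

To actually disable weight decay, `weight_decay=0.0` would have to be passed explicitly.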
|