|
@ -14,6 +14,7 @@ class ColorTransformerModel(L.LightningModule): |
|
|
depth: int = 1, |
|
|
depth: int = 1, |
|
|
bias: bool = False, |
|
|
bias: bool = False, |
|
|
alpha: float = 0, |
|
|
alpha: float = 0, |
|
|
|
|
|
lr: float = 0.01, |
|
|
): |
|
|
): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.save_hyperparameters() |
|
|
self.save_hyperparameters() |
|
@ -84,7 +85,7 @@ class ColorTransformerModel(L.LightningModule): |
|
|
def configure_optimizers(self): |
|
|
def configure_optimizers(self): |
|
|
optimizer = torch.optim.SGD( |
|
|
optimizer = torch.optim.SGD( |
|
|
self.parameters(), |
|
|
self.parameters(), |
|
|
lr=0.1, |
|
|
lr=self.hparams.lr, |
|
|
) |
|
|
) |
|
|
lr_scheduler = ReduceLROnPlateau( |
|
|
lr_scheduler = ReduceLROnPlateau( |
|
|
optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True |
|
|
optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True |
|
|