diff --git a/model.py b/model.py index 7486af3..97909ba 100644 --- a/model.py +++ b/model.py @@ -14,6 +14,7 @@ class ColorTransformerModel(L.LightningModule): depth: int = 1, bias: bool = False, alpha: float = 0, + lr: float = 0.01, ): super().__init__() self.save_hyperparameters() @@ -84,7 +85,7 @@ class ColorTransformerModel(L.LightningModule): def configure_optimizers(self): optimizer = torch.optim.SGD( self.parameters(), - lr=0.1, + lr=self.hparams.lr, ) lr_scheduler = ReduceLROnPlateau( optimizer, mode="min", factor=0.05, patience=5, cooldown=10, verbose=True diff --git a/newsearch.py b/newsearch.py index 2adc493..7630e1e 100644 --- a/newsearch.py +++ b/newsearch.py @@ -33,7 +33,7 @@ alpha_values = [1.0] widths, depths = [512], [4] batch_size_values = [256] -max_epochs_values = [20] +max_epochs_values = [50] seeds = list(range(21, 1992)) optimizers = [ # "Adagrad",