|
@ -86,15 +86,17 @@ class ColorTransformerModel(pl.LightningModule): |
|
|
|
|
|
|
|
|
# Neural network layers |
|
|
# Neural network layers |
|
|
self.network = nn.Sequential( |
|
|
self.network = nn.Sequential( |
|
|
nn.Linear(5, 64), # Input layer |
|
|
nn.Linear(3, 16), |
|
|
nn.Tanh(), |
|
|
nn.ReLU(), |
|
|
nn.Linear(64, 128), |
|
|
nn.Linear(16, 16), |
|
|
nn.Tanh(), |
|
|
nn.ReLU(), |
|
|
|
|
|
nn.Linear(16, 128), |
|
|
|
|
|
nn.ReLU(), |
|
|
nn.Linear(128, 128), |
|
|
nn.Linear(128, 128), |
|
|
nn.Tanh(), |
|
|
nn.ReLU(), |
|
|
nn.Linear(128, 64), |
|
|
nn.Linear(128, 64), |
|
|
nn.Tanh(), |
|
|
nn.ReLU(), |
|
|
nn.Linear(64, 1), # Output layer |
|
|
nn.Linear(64, 1), |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
@ -114,9 +116,10 @@ class ColorTransformerModel(pl.LightningModule): |
|
|
outputs, |
|
|
outputs, |
|
|
) |
|
|
) |
|
|
alpha = self.hparams.alpha |
|
|
alpha = self.hparams.alpha |
|
|
loss = (p_loss + alpha * s_loss) / (1 + alpha) |
|
|
loss = p_loss + alpha * s_loss |
|
|
self.log("hp_metric", loss) |
|
|
self.log("hp_metric", loss) |
|
|
self.log("train_loss", loss) |
|
|
self.log("p_loss", p_loss) |
|
|
|
|
|
self.log("s_loss", s_loss) |
|
|
return loss |
|
|
return loss |
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
def configure_optimizers(self): |
|
@ -131,6 +134,6 @@ class ColorTransformerModel(pl.LightningModule): |
|
|
"optimizer": optimizer, |
|
|
"optimizer": optimizer, |
|
|
"lr_scheduler": { |
|
|
"lr_scheduler": { |
|
|
"scheduler": lr_scheduler, |
|
|
"scheduler": lr_scheduler, |
|
|
"monitor": "train_loss", # Specify the metric to monitor |
|
|
"monitor": "hp_metric", # Specify the metric to monitor |
|
|
}, |
|
|
}, |
|
|
} |
|
|
} |
|
|