Browse Source

slight refactors

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
06a05cdf65
  1. 8
      callbacks.py
  2. 2
      check.py
  3. 6
      main.py
  4. 23
      model.py
  5. 2
      out/index.html
  6. 2
      utils.py

8
callbacks.py

@ -36,3 +36,11 @@ class SaveImageCallback(pl.Callback):
fname = Path(f"{self.final_dir}") / Path(f"v{version}") fname = Path(f"{self.final_dir}") / Path(f"v{version}")
pl_module.eval() pl_module.eval()
create_circle(pl_module, fname=fname) create_circle(pl_module, fname=fname)
if self.save_interval > 0:
import os
log_dir = str(Path(pl_module.trainer.logger.log_dir))
fps = 12
os.system(
f'ffmpeg -i {log_dir}/e%04d.png -c:v libx264 -vf "fps={fps},format=yuv420p,pad=ceil(iw/2)*2:ceil(ih/2)*2" {log_dir}/a{version}.mp4'
)

2
check.py

@ -57,7 +57,7 @@ def create_circle(ckpt: str, fname: str):
N = len(colors) N = len(colors)
# Create a plot with these hues in a circle # Create a plot with these hues in a circle
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) fig, ax = plt.subplots(figsize=(3, 3), subplot_kw=dict(polar=True))
# Each wedge in the circle # Each wedge in the circle
theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2 theta = np.linspace(0, 2 * np.pi, N + 1) + np.pi / 2

6
main.py

@ -46,14 +46,14 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
early_stop_callback = EarlyStopping( early_stop_callback = EarlyStopping(
monitor="hp_metric", # Metric to monitor monitor="hp_metric", # Metric to monitor
min_delta=0, # Minimum change in the monitored quantity to qualify as an improvement min_delta=1e-5, # Minimum change in the monitored quantity to qualify as an improvement
patience=50, # Number of epochs with no improvement after which training will be stopped patience=12, # Number of epochs with no improvement after which training will be stopped
mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it. mode="min", # Mode can be either 'min' for minimizing the monitored quantity or 'max' for maximizing it.
verbose=True, verbose=True,
) )
save_img_callback = SaveImageCallback( save_img_callback = SaveImageCallback(
save_interval=0, save_interval=1,
final_dir="out", final_dir="out",
) )

23
model.py

@ -86,15 +86,17 @@ class ColorTransformerModel(pl.LightningModule):
# Neural network layers # Neural network layers
self.network = nn.Sequential( self.network = nn.Sequential(
nn.Linear(5, 64), # Input layer nn.Linear(3, 16),
nn.Tanh(), nn.ReLU(),
nn.Linear(64, 128), nn.Linear(16, 16),
nn.Tanh(), nn.ReLU(),
nn.Linear(16, 128),
nn.ReLU(),
nn.Linear(128, 128), nn.Linear(128, 128),
nn.Tanh(), nn.ReLU(),
nn.Linear(128, 64), nn.Linear(128, 64),
nn.Tanh(), nn.ReLU(),
nn.Linear(64, 1), # Output layer nn.Linear(64, 1),
) )
def forward(self, x): def forward(self, x):
@ -114,9 +116,10 @@ class ColorTransformerModel(pl.LightningModule):
outputs, outputs,
) )
alpha = self.hparams.alpha alpha = self.hparams.alpha
loss = (p_loss + alpha * s_loss) / (1 + alpha) loss = p_loss + alpha * s_loss
self.log("hp_metric", loss) self.log("hp_metric", loss)
self.log("train_loss", loss) self.log("p_loss", p_loss)
self.log("s_loss", s_loss)
return loss return loss
def configure_optimizers(self): def configure_optimizers(self):
@ -131,6 +134,6 @@ class ColorTransformerModel(pl.LightningModule):
"optimizer": optimizer, "optimizer": optimizer,
"lr_scheduler": { "lr_scheduler": {
"scheduler": lr_scheduler, "scheduler": lr_scheduler,
"monitor": "train_loss", # Specify the metric to monitor "monitor": "hp_metric", # Specify the metric to monitor
}, },
} }

2
out/index.html

@ -72,7 +72,7 @@
function loadImages() { function loadImages() {
var gallery = document.getElementById('gallery'); var gallery = document.getElementById('gallery');
for (var i = 1; i <= 100; i++) { // Changed from i <= 100 to i < 100 for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100
let imageName = 'v' + i + '.png'; let imageName = 'v' + i + '.png';
let img = document.createElement('img'); let img = document.createElement('img');
img.src = imageName; img.src = imageName;

2
utils.py

@ -1,7 +1,7 @@
import torch import torch
def preprocess_data(data, skip=False): def preprocess_data(data, skip=True):
# Assuming 'data' is a tensor of shape [n_samples, 3] # Assuming 'data' is a tensor of shape [n_samples, 3]
if not skip: if not skip:
# Compute argmin and argmax for each row # Compute argmin and argmax for each row

Loading…
Cancel
Save