Browse Source

remove s_loss entirely

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
1e818aa977
  1. 14
      model.py
  2. 2
      newsearch.py

14
model.py

@@ -4,7 +4,6 @@ import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import calculate_separation_loss, preservation_loss # noqa: F401
from utils import PURE_HSV, PURE_RGB
class ColorTransformerModel(L.LightningModule):
@@ -46,23 +45,14 @@ class ColorTransformerModel(L.LightningModule):
def training_step(self, batch, batch_idx):
inputs, labels = batch # x are the RGB inputs, labels are the strings
outputs = self.forward(inputs)
# s_loss = calculate_separation_loss(model=self)
# preserve distance to pure R, G, B. this acts kind of like labeled data.
s_loss = preservation_loss(
inputs,
outputs,
target_inputs=PURE_RGB,
target_outputs=PURE_HSV,
)
p_loss = preservation_loss(
inputs,
outputs,
)
alpha = self.hparams.alpha
loss = p_loss + alpha * s_loss
# alpha = self.hparams.alpha # TODO: decide what to do with this...
loss = p_loss
self.log("hp_metric", loss)
self.log("p_loss", p_loss)
self.log("s_loss", s_loss)
return loss
def validation_step(self, batch):

2
newsearch.py

@@ -21,7 +21,7 @@ NUM_JOBS = 100
# alpha_values = list(np.round(np.linspace(2, 4, 21), 4))
# learning_rate_values = list(np.round(np.logspace(-5, -3, 21), 5))
learning_rate_values = [1e-2]
alpha_values = [0, 1, 2]
alpha_values = [0]
widths = [2**k for k in range(4, 15)]
# learning_rate_values = [5e-4]
batch_size_values = [256]

Loading…
Cancel
Save