Compare commits
2 Commits
a9e772f34e
...
5687f30818
Author | SHA1 | Date | |
---|---|---|---|
|
5687f30818 | ||
|
5def982f12 |
@ -43,7 +43,9 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
|
||||
transformed_norm = circle_norm(outputs, target_outputs) * 2
|
||||
|
||||
diff = torch.pow(rgb_norm - transformed_norm, 2)
|
||||
N = torch.count_nonzero(rgb_norm)
|
||||
N = len(outputs)
|
||||
N = (N * (N - 1)) / 2
|
||||
# N = torch.count_nonzero(rgb_norm)
|
||||
return torch.sum(diff) / N
|
||||
|
||||
|
||||
|
7
model.py
7
model.py
@ -31,7 +31,8 @@ class ColorTransformerModel(L.LightningModule):
|
||||
d = self.hparams.depth
|
||||
bias = self.hparams.bias
|
||||
if self.hparams.loop:
|
||||
midlayers = [nn.Linear(w, w, bias=bias), t()] * d
|
||||
midlayers = []
|
||||
midlayers += [nn.Linear(w, w, bias=bias), t()] * d
|
||||
else:
|
||||
midlayers = sum(
|
||||
[
|
||||
@ -64,8 +65,8 @@ class ColorTransformerModel(L.LightningModule):
|
||||
p_loss = preservation_loss(
|
||||
inputs,
|
||||
outputs,
|
||||
target_inputs=rgb_tensor,
|
||||
target_outputs=self.forward(rgb_tensor),
|
||||
# target_inputs=rgb_tensor,
|
||||
# target_outputs=self.forward(rgb_tensor),
|
||||
)
|
||||
alpha = self.hparams.alpha
|
||||
|
||||
|
10
newsearch.py
10
newsearch.py
@ -27,19 +27,19 @@ learning_rate_values = [1e-3]
|
||||
# learning_rate_values = [5e-4]
|
||||
|
||||
# alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised.
|
||||
alpha_values = [0.9]
|
||||
alpha_values = [0]
|
||||
# widths = [2**k for k in range(4, 13)]
|
||||
# depths = [1, 2, 4, 8, 16]
|
||||
widths, depths = [512], [8]
|
||||
widths, depths = [512], [4]
|
||||
|
||||
batch_size_values = [256]
|
||||
max_epochs_values = [100]
|
||||
seeds = list(range(21, 1992))
|
||||
optimizers = [
|
||||
# "Adagrad",
|
||||
# "Adam",
|
||||
"Adam",
|
||||
# "SGD",
|
||||
"AdamW",
|
||||
# "AdamW",
|
||||
# "LBFGS",
|
||||
# "RAdam",
|
||||
# "RMSprop",
|
||||
@ -80,7 +80,7 @@ python newmain.py fit \
|
||||
--model.depth {d} \
|
||||
--model.bias true \
|
||||
--model.loop true \
|
||||
--model.transform relu \
|
||||
--model.transform tanh \
|
||||
--trainer.min_epochs 10 \
|
||||
--trainer.max_epochs {me} \
|
||||
--trainer.log_every_n_steps 3 \
|
||||
|
Loading…
Reference in New Issue
Block a user