Compare commits
3 Commits
4342a54cc8
...
d318480b7c
Author | SHA1 | Date | |
---|---|---|---|
|
d318480b7c | ||
|
1c116f3f12 | ||
|
c7ffd09fb4 |
4
model.py
4
model.py
@ -58,8 +58,8 @@ class ColorTransformerModel(L.LightningModule):
|
||||
alpha = self.hparams.alpha
|
||||
|
||||
# N = len(outputs)
|
||||
distance = circle_norm(outputs, labels).mean()
|
||||
# distance = torch.norm(outputs - labels).mean()
|
||||
# distance = circle_norm(outputs, labels).mean()
|
||||
distance = torch.norm(outputs - labels).mean()
|
||||
|
||||
# Backprop with this:
|
||||
loss = (1 - alpha) * p_loss + alpha * distance
|
||||
|
@ -27,12 +27,12 @@ learning_rate_values = [1e-3]
|
||||
# learning_rate_values = [5e-4]
|
||||
|
||||
# alpha_values = [0, .25, 0.5, 0.75, 1] # alpha = 0 is unsupervised. alpha = 1 is supervised.
|
||||
alpha_values = [0]
|
||||
alpha_values = [1.0]
|
||||
# widths = [2**k for k in range(4, 13)]
|
||||
# depths = [1, 2, 4, 8, 16]
|
||||
widths, depths = [512], [4]
|
||||
|
||||
batch_size_values = [64, 256, 1024]
|
||||
batch_size_values = [256]
|
||||
max_epochs_values = [100]
|
||||
seeds = list(range(21, 1992))
|
||||
optimizers = [
|
||||
@ -73,7 +73,7 @@ for idx, params in enumerate(search_params):
|
||||
python newmain.py fit \
|
||||
--seed_everything {s} \
|
||||
--data.batch_size {bs} \
|
||||
--data.train_size 50000 \
|
||||
--data.train_size 0 \
|
||||
--data.val_size 10000 \
|
||||
--model.alpha {a} \
|
||||
--model.width {w} \
|
||||
|
Loading…
Reference in New Issue
Block a user