Compare commits

5 Commits: b6d9f94d8e ... 4342a54cc8

| Author | SHA1 | Date |
|---|---|---|
| | 4342a54cc8 | |
| | 248d1a72f9 | |
| | 9e4861a272 | |
| | e5b6f287a3 | |
| | 865e7f5104 | |
.gitignore (vendored), 1 change

```diff
@@ -6,3 +6,4 @@ out/
 *.tar.gz
 .pat
 out*
+.lr*
```
losses.py, 10 changes

```diff
@@ -1,6 +1,6 @@
 import torch
 
-from utils import PURE_RGB
+from utils import RGBMYC_ANCHOR
 
 # def smoothness_loss(outputs):
 #     # Sort outputs for smoothness calculation
@@ -40,11 +40,11 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # print(rgb_norm)
 
     # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm = circle_norm(outputs, target_outputs)
+    transformed_norm = circle_norm(outputs, target_outputs) * 2
 
     diff = torch.pow(rgb_norm - transformed_norm, 2)
-    N = len(outputs)
-    return torch.sum(diff) / (N * (N - 1)) / 2
+    N = torch.count_nonzero(rgb_norm)
+    return torch.sum(diff) / N
 
 
 def circle_norm(vector, other_vector):
@@ -68,7 +68,7 @@ def separation_loss(red, green, blue):
 def calculate_separation_loss(model):
     # TODO: remove
     # Wrapper function to calculate separation loss
-    outputs = model(PURE_RGB.to(model.device))
+    outputs = model(RGBMYC_ANCHOR.to(model.device))
     red, green, blue = outputs[0], outputs[1], outputs[2]
     return separation_loss(red, green, blue)
 
```
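The body of `circle_norm` is not part of this diff, so the following is only a minimal sketch of the revised tail of `preservation_loss`, assuming `circle_norm` returns wrap-around (modulo 1) distances between the 1D outputs and that `rgb_norm` holds the corresponding input-space distances. The point of the change is the new normalization: dividing by the number of nonzero input-space distances instead of by N(N - 1)/2.

```python
import torch

def circle_norm(vector, other_vector):
    # Assumed implementation: distance on a circle of circumference 1,
    # i.e. the wrap-around difference between every pair of 1D positions.
    d = torch.abs(vector.reshape(-1, 1) - other_vector.reshape(1, -1)) % 1.0
    return torch.minimum(d, 1.0 - d)

def preservation_loss_tail(rgb_norm, outputs, target_outputs):
    # New in this commit: scale the circular distance by 2 (presumably so its
    # maximum of 0.5 maps to 1) and normalize by the count of nonzero input
    # distances rather than by N * (N - 1) / 2.
    transformed_norm = circle_norm(outputs, target_outputs) * 2
    diff = torch.pow(rgb_norm - transformed_norm, 2)
    N = torch.count_nonzero(rgb_norm)
    return torch.sum(diff) / N
```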
makefile, 2 changes

```diff
@@ -36,7 +36,7 @@ help:
 	# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
 	python newmain.py fit --help
 
-search:
+search: lint
 	python newsearch.py
 
 hsv:
```
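Since `lint` is now listed as a prerequisite of `search`, running `make search` executes the `lint` target before launching `python newsearch.py`.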
model.py, 15 changes

```diff
@@ -3,9 +3,8 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
-from losses import preservation_loss
-
-# from utils import PURE_RGB
+from losses import circle_norm, preservation_loss
+from utils import RGBMYC_ANCHOR
 
 
 class ColorTransformerModel(L.LightningModule):
@@ -49,18 +48,18 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # rgb_tensor = PURE_RGB.to(self.device)
+        rgb_tensor = RGBMYC_ANCHOR.to(self.device)  # noqa: F841
         p_loss = preservation_loss(
             inputs,
             outputs,
-            # target_inputs=rgb_tensor,
-            # target_outputs=self.forward(rgb_tensor),
+            target_inputs=rgb_tensor,
+            target_outputs=self.forward(rgb_tensor),
         )
         alpha = self.hparams.alpha
 
         # N = len(outputs)
-        # distance = circle_norm(outputs, labels) / (N*(N-1)/2)
-        distance = torch.norm(outputs - labels).mean()
+        distance = circle_norm(outputs, labels).mean()
+        # distance = torch.norm(outputs - labels).mean()
 
         # Backprop with this:
         loss = (1 - alpha) * p_loss + alpha * distance
```
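Two things change in `training_step`: the anchor tensor switches from `PURE_RGB` to `RGBMYC_ANCHOR` (presumably adding the secondary colors magenta, yellow, and cyan) and is now actually passed to `preservation_loss`, and the supervised distance term now uses `circle_norm` instead of a plain `torch.norm`. How the two terms are blended is unchanged; a minimal standalone sketch of that blend, with hypothetical values:

```python
import torch

def combined_loss(p_loss: torch.Tensor, distance: torch.Tensor, alpha: float) -> torch.Tensor:
    # Blend used at the end of training_step: alpha = 0 keeps only the
    # preservation term, alpha = 1 keeps only the circular label distance.
    return (1 - alpha) * p_loss + alpha * distance

# Hypothetical values; note that the search configuration below sweeps
# alpha_values = [0], which makes the supervised distance term drop out entirely.
print(combined_loss(torch.tensor(0.3), torch.tensor(0.7), alpha=0.0))  # tensor(0.3000)
```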
Hyperparameter search script:

```diff
@@ -32,14 +32,14 @@ alpha_values = [0]
 # depths = [1, 2, 4, 8, 16]
 widths, depths = [512], [4]
 
-batch_size_values = [256]
+batch_size_values = [64, 256, 1024]
 max_epochs_values = [100]
 seeds = list(range(21, 1992))
 optimizers = [
     # "Adagrad",
-    "Adam",
+    # "Adam",
     # "SGD",
-    # "AdamW",
+    "AdamW",
     # "LBFGS",
     # "RAdam",
     # "RMSprop",
@@ -73,7 +73,7 @@ for idx, params in enumerate(search_params):
     python newmain.py fit \
         --seed_everything {s} \
         --data.batch_size {bs} \
-        --data.train_size 0 \
+        --data.train_size 50000 \
         --data.val_size 10000 \
         --model.alpha {a} \
         --model.width {w} \
```
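A minimal sketch of how the swept values above plausibly expand into `newmain.py fit` commands. The construction of `search_params` is not shown in this diff, so the Cartesian-product ordering, the unpacking names, and the use of `random.sample` here are assumptions for illustration; only the flags visible in the hunk are echoed.

```python
from itertools import product
from random import sample

alpha_values = [0]
widths, depths = [512], [4]
batch_size_values = [64, 256, 1024]
max_epochs_values = [100]
seeds = list(range(21, 1992))
optimizers = ["AdamW"]

# Assumed: search_params is every combination of the swept values above.
search_params = list(
    product(alpha_values, widths, depths, batch_size_values,
            max_epochs_values, optimizers, seeds)
)

for idx, params in enumerate(sample(search_params, k=3)):
    a, w, d, bs, me, opt, s = params
    # The real command continues with further flags; only those shown in the diff appear here.
    cmd = (
        f"python newmain.py fit "
        f"--seed_everything {s} "
        f"--data.batch_size {bs} "
        f"--data.train_size 50000 "
        f"--data.val_size 10000 "
        f"--model.alpha {a} "
        f"--model.width {w}"
    )
    print(idx, cmd)
```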