Compare commits

...

5 Commits

Author                SHA1        Message                                              Date
Michael Pilosov, PhD  4342a54cc8  anchors and more train data, exp w batch             2024-01-28 02:43:35 +00:00
Michael Pilosov, PhD  248d1a72f9  re-anchor, pretty meh results in-batch               2024-01-28 01:52:21 +00:00
Michael Pilosov, PhD  9e4861a272  try unsupervised again, but with 10k random samples  2024-01-28 01:33:10 +00:00
Michael Pilosov, PhD  e5b6f287a3  xkcd colors may be too few to learn from. need 10x   2024-01-28 01:32:56 +00:00
Michael Pilosov, PhD  865e7f5104  supervised questionable                              2024-01-28 01:25:10 +00:00
6 changed files with 19 additions and 19 deletions

.gitignore

@@ -6,3 +6,4 @@ out/
 *.tar.gz
 .pat
 out*
+.lr*

losses.py

@@ -1,6 +1,6 @@
 import torch

-from utils import PURE_RGB
+from utils import RGBMYC_ANCHOR

 # def smoothness_loss(outputs):
 #     # Sort outputs for smoothness calculation
@@ -40,11 +40,11 @@ def preservation_loss(inputs, outputs, target_inputs=None, target_outputs=None):
     # print(rgb_norm)

     # Calculate 1D Space Norm (modulo 1 to account for circularity)
-    transformed_norm = circle_norm(outputs, target_outputs)
+    transformed_norm = circle_norm(outputs, target_outputs) * 2

     diff = torch.pow(rgb_norm - transformed_norm, 2)
-    N = len(outputs)
-    return torch.sum(diff) / (N * (N - 1)) / 2
+    N = torch.count_nonzero(rgb_norm)
+    return torch.sum(diff) / N


 def circle_norm(vector, other_vector):
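
Note: the new denominator counts only the nonzero pairwise RGB distances instead of all N*(N-1)/2 pairs, and the factor of 2 appears to rescale the circular distance (whose maximum on the unit circle is 0.5) to the [0, 1] range of the RGB norm. A minimal sketch of the assumed circle_norm semantics (illustrative only, not this repo's exact implementation):

    import torch

    def circle_norm_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise 1-D distances wrapped around the unit circle:
        # values 0.9 and 0.1 are 0.2 apart, not 0.8.
        d = torch.abs(a.flatten()[:, None] - b.flatten()[None, :]) % 1.0
        return torch.minimum(d, 1.0 - d)  # maximum possible distance is 0.5
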
@@ -68,7 +68,7 @@ def separation_loss(red, green, blue):
 def calculate_separation_loss(model):
     # TODO: remove
     # Wrapper function to calculate separation loss
-    outputs = model(PURE_RGB.to(model.device))
+    outputs = model(RGBMYC_ANCHOR.to(model.device))
     red, green, blue = outputs[0], outputs[1], outputs[2]
     return separation_loss(red, green, blue)

Makefile

@@ -36,7 +36,7 @@ help:
 	# python newmain.py fit --lr_scheduler.help lightning.pytorch.cli.ReduceLROnPlateau
 	python newmain.py fit --help

-search:
+search: lint
 	python newsearch.py

 hsv:
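
Note: in make, "search: lint" declares lint as a prerequisite target, so running `make search` now lints the code before launching newsearch.py.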


@@ -3,9 +3,8 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-from losses import preservation_loss
-# from utils import PURE_RGB
+from losses import circle_norm, preservation_loss
+from utils import RGBMYC_ANCHOR


 class ColorTransformerModel(L.LightningModule):
@@ -49,18 +48,18 @@ class ColorTransformerModel(L.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # rgb_tensor = PURE_RGB.to(self.device)
+        rgb_tensor = RGBMYC_ANCHOR.to(self.device)  # noqa: F841
         p_loss = preservation_loss(
             inputs,
             outputs,
-            # target_inputs=rgb_tensor,
-            # target_outputs=self.forward(rgb_tensor),
+            target_inputs=rgb_tensor,
+            target_outputs=self.forward(rgb_tensor),
         )
         alpha = self.hparams.alpha
-        # N = len(outputs)
-        # distance = circle_norm(outputs, labels) / (N*(N-1)/2)
-        distance = torch.norm(outputs - labels).mean()
+        distance = circle_norm(outputs, labels).mean()
+        # distance = torch.norm(outputs - labels).mean()
         # Backprop with this:
         loss = (1 - alpha) * p_loss + alpha * distance
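
Note: the loss is a convex blend of the two terms: alpha = 0 trains purely on structure preservation, alpha = 1 purely on matching the numeric label targets (despite the inline comment calling them strings), and swapping in circle_norm makes the label-distance term respect hue wraparound. A hedged sketch of the change in the distance term, assuming 1-D outputs in [0, 1) and elementwise circle_norm behavior (values are hypothetical):

    import torch

    outputs = torch.tensor([0.95, 0.40])  # hypothetical 1-D embeddings
    labels = torch.tensor([0.05, 0.50])   # hypothetical numeric targets

    # old: plain Euclidean distance, which overstates 0.95 vs 0.05
    distance_old = torch.norm(outputs - labels).mean()

    # new: wrap-around distance on the unit circle
    d = torch.abs(outputs - labels) % 1.0
    distance_new = torch.minimum(d, 1.0 - d).mean()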

newsearch.py

@@ -32,14 +32,14 @@ alpha_values = [0]
 # depths = [1, 2, 4, 8, 16]
 widths, depths = [512], [4]

-batch_size_values = [256]
+batch_size_values = [64, 256, 1024]
 max_epochs_values = [100]
 seeds = list(range(21, 1992))
 optimizers = [
     # "Adagrad",
-    "Adam",
+    # "Adam",
     # "SGD",
-    # "AdamW",
+    "AdamW",
     # "LBFGS",
     # "RAdam",
     # "RMSprop",
@@ -73,7 +73,7 @@ for idx, params in enumerate(search_params):
     python newmain.py fit \
     --seed_everything {s} \
     --data.batch_size {bs} \
-    --data.train_size 0 \
+    --data.train_size 50000 \
     --data.val_size 10000 \
     --model.alpha {a} \
     --model.width {w} \
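
Note: widening batch_size_values to three values and switching the optimizer from Adam to AdamW multiplies the sweep accordingly; each combination is rendered into a `python newmain.py fit ...` command like the template above. A hedged sketch of how such a grid typically expands (illustrative; the actual loop in this file may differ, and the flags are those shown in the hunk):

    from itertools import product

    batch_size_values = [64, 256, 1024]
    alpha_values = [0]
    seeds = [21, 22]  # truncated for illustration

    for bs, a, s in product(batch_size_values, alpha_values, seeds):
        print(
            f"python newmain.py fit "
            f"--seed_everything {s} --data.batch_size {bs} "
            f"--data.train_size 50000 --data.val_size 10000 --model.alpha {a}"
        )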

utils.py

@@ -34,7 +34,7 @@ def extract_colors():
     return rgb_tensor, xkcd_color_names


-PURE_RGB = preprocess_data(
+RGBMYC_ANCHOR = preprocess_data(
     torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
 )
 PURE_HSV = torch.tensor(
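
Note: the rename from PURE_RGB to RGBMYC_ANCHOR matches what the tensor actually holds: torch.eye(3) contributes the red, green, and blue corners, and adding the column-permuted identity yields magenta, yellow, and cyan. A quick check, assuming only torch:

    import torch

    anchors = torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
    print(anchors)
    # tensor([[1., 0., 0.],    red
    #         [0., 1., 0.],    green
    #         [0., 0., 1.],    blue
    #         [1., 0., 1.],    magenta
    #         [1., 1., 0.],    yellow
    #         [0., 1., 1.]])   cyan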