
Commit fa0ba0cee7 — "this looks good"
Branch: new-sep-loss
Author: mm, 11 months ago
Changed files:
  1. check.py (15 changed lines)
  2. dataloader.py (22 changed lines)
  3. losses.py (31 changed lines)
  4. main.py (7 changed lines)
  5. makefile (2 changed lines)
  6. model.py (18 changed lines)
  7. scrape.py (13 changed lines)

check.py (15 changed lines)

@@ -2,9 +2,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch

-from dataloader import extract_colors
+from dataloader import extract_colors, preprocess_data
 from model import ColorTransformerModel


 def make_image(ckpt: str, fname: str, color=True):
     M = ColorTransformerModel.load_from_checkpoint(ckpt)
@@ -17,14 +18,15 @@ def make_image(ckpt: str, fname: str, color=True):
     rgb_tensor, names = extract_colors()
     rgb_values = rgb_tensor.detach().numpy()
+    rgb_tensor = preprocess_data(rgb_tensor)
     preds = M(rgb_tensor)
     sorted_inds = np.argsort(preds.detach().numpy().ravel())
-    fig, ax = plt.subplots(figsize=(10, 5))
+    fig, ax = plt.subplots(figsize=(20, 5))
     for i in range(len(sorted_inds)):
         idx = sorted_inds[i]
         color = rgb_values[idx]
-        ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5)
+        ax.plot([i, i], [0, 5], lw=0.5, c=color, antialiased=False, alpha=1)
     ax.axis("off")
     # ax.axis("square")
@@ -32,7 +34,8 @@ def make_image(ckpt: str, fname: str, color=True):
 if __name__ == "__main__":
-    name = "color_128_0.3_1.00e-06"
-    ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
+    # name = "color_128_0.3_1.00e-06"
+    name = "color_64_1_1.0e-3.png"
+    # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
+    ckpt = "/teamspace/studios/this_studio/colors/lightning_logs/version_26/checkpoints/epoch=99-step=1500.ckpt"
     make_image(ckpt, fname=name)
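Note: with nn.Linear(5, 128) as the new first layer (see model.py below), the raw [n, 3] RGB tensor must pass through preprocess_data before the forward pass, which is exactly what the added line does. A minimal sketch of the resulting inference path, with shapes inferred from the diff:

from dataloader import extract_colors, preprocess_data
from model import ColorTransformerModel

M = ColorTransformerModel.load_from_checkpoint(ckpt)  # ckpt path as above
rgb_tensor, names = extract_colors()  # [n, 3] raw RGB, kept for plot colors
feats = preprocess_data(rgb_tensor)   # [n, 5]: RGB plus scaled argmin/argmax
preds = M(feats)                      # 1-D coordinate per color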

dataloader.py (22 changed lines)

@@ -18,14 +18,16 @@ def extract_colors():
 def create_dataloader(**kwargs):
     rgb_tensor, _ = extract_colors()
+    rgb_tensor = preprocess_data(rgb_tensor)
     # Creating a dataset and data loader
-    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))  # Dummy labels
+    dataset = TensorDataset(rgb_tensor, torch.zeros(len(rgb_tensor)))
     train_dataloader = DataLoader(dataset, **kwargs)
     return train_dataloader


 def create_named_dataloader(**kwargs):
     rgb_tensor, xkcd_color_names = extract_colors()
+    rgb_tensor = preprocess_data(rgb_tensor)
     # Creating a dataset with RGB values and their corresponding color names
     dataset_with_names = [
         (rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor))
@@ -34,6 +36,24 @@ def create_named_dataloader(**kwargs):
     return train_dataloader_with_names


+def preprocess_data(data):
+    # Assuming 'data' is a tensor of shape [n_samples, 3]
+    # Compute argmin and argmax for each row
+    argmin_values = torch.argmin(data, dim=1, keepdim=True).float()
+    argmax_values = torch.argmax(data, dim=1, keepdim=True).float()
+
+    # Normalize or scale argmin and argmax if necessary
+    # For example, here I am just dividing by the number of features
+    argmin_values /= data.shape[1]
+    argmax_values /= data.shape[1]
+
+    # Concatenate the argmin and argmax values to the original data
+    new_data = torch.cat((data, argmin_values, argmax_values), dim=1)
+    return new_data
+
+
 if __name__ == "__main__":
     batch_size = 4
     train_dataloader = create_dataloader(batch_size=batch_size, shuffle=True)
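Note: preprocess_data widens each RGB row from 3 to 5 features by appending the channel indices of its minimum and maximum, each divided by data.shape[1] == 3. A worked example (PyTorch's argmin/argmax return the first index on ties):

import torch
from dataloader import preprocess_data

data = torch.tensor([[1.0, 0.0, 0.0]])  # pure red, shape [1, 3]
# argmin -> channel 1 (first zero), argmax -> channel 0 (red), both scaled by 1/3
print(preprocess_data(data))  # tensor([[1.0000, 0.0000, 0.0000, 0.3333, 0.0000]])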

losses.py (31 changed lines)

@@ -1,16 +1,15 @@
 import torch

-
-def weighted_loss(inputs, outputs, alpha):
-    # Calculate RGB Norm (Perceptual Difference)
-    rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
+# def weighted_loss(inputs, outputs, alpha):
+#     # Calculate RGB Norm (Perceptual Difference)
+#     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)

-    # Calculate 1D Space Norm
-    transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)
+#     # Calculate 1D Space Norm
+#     transformed_norm = torch.norm(outputs[:, None] - outputs[None, :], dim=-1)

-    # Weighted Loss
-    loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
-    return torch.mean(loss)
+#     # Weighted Loss
+#     loss = alpha * rgb_norm + (1 - alpha) * transformed_norm
+#     return torch.mean(loss)


 # def enhanced_loss(inputs, outputs, alpha, distinct_threshold):
@@ -33,7 +32,7 @@ def weighted_loss(inputs, outputs, alpha):
 #     return torch.mean(loss)


-def enhanced_loss(inputs, outputs, alpha):
+def preservation_loss(inputs, outputs):
     # Calculate RGB Norm
     rgb_norm = torch.norm(inputs[:, None, :] - inputs[None, :, :], dim=-1)
@@ -42,19 +41,19 @@ def enhanced_loss(inputs, outputs, alpha):
     # Distance Preservation Component
     # Encourages the model to keep relative distances from the RGB space in the transformed space
-    distance_preservation_loss = torch.mean(torch.abs(rgb_norm - transformed_norm))
-
-    # Combined Loss
-    loss = alpha * distance_preservation_loss + (1 - alpha) * smoothness_loss(outputs)
-    return loss
+    return torch.mean(torch.abs(rgb_norm - transformed_norm))


 def smoothness_loss(outputs):
     # Sort outputs for smoothness calculation
     sorted_outputs, _ = torch.sort(outputs, dim=0)
+    first_elements = sorted_outputs[:2]
+
+    # Concatenate the first element at the end of the sorted_outputs
+    extended_sorted_outputs = torch.cat((sorted_outputs, first_elements), dim=0)

     # Calculate smoothness in the sorted outputs
-    first_derivative = torch.diff(sorted_outputs, n=1, dim=0)
+    first_derivative = torch.diff(extended_sorted_outputs, n=1, dim=0)
     second_derivative = torch.diff(first_derivative, n=1, dim=0)
     smoothness_loss = torch.mean(torch.abs(second_derivative))
     return smoothness_loss
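Note: the alpha-weighted combination moved out of losses.py and into training_step (see model.py below), so each term can be logged on its own. The other change is the wrap-around: repeating the two smallest sorted values at the end means the second derivative is also taken across the seam, treating the 1-D embedding as circular. A rough shape check, assuming outputs of shape [n, 1]:

import torch

outputs = torch.rand(8, 1)
s, _ = torch.sort(outputs, dim=0)   # [8, 1]
ext = torch.cat((s, s[:2]), dim=0)  # [10, 1], two seam rows appended
d1 = torch.diff(ext, n=1, dim=0)    # [9, 1]; final entries cross the seam
d2 = torch.diff(d1, n=1, dim=0)     # [8, 1]: one curvature term per point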

main.py (7 changed lines)

@@ -50,12 +50,15 @@ if __name__ == "__main__":
         num_workers=args.num_workers,
     )

-    # Initialize model with parsed arguments
-    model = ColorTransformerModel(
+    params = argparse.Namespace(
         alpha=args.alpha,
         learning_rate=args.lr,
+        batch_size=args.bs,
     )

+    # Initialize model with parsed arguments
+    model = ColorTransformerModel(params)
+
     # Initialize trainer with parsed arguments
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
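Note: bundling the CLI arguments into an argparse.Namespace lets the model call save_hyperparameters(params) on the whole object, so alpha, learning_rate, and batch_size are all checkpointed and reachable via self.hparams. A minimal sketch (placeholder values):

import argparse

params = argparse.Namespace(alpha=1.0, learning_rate=1e-3, batch_size=64)
# inside ColorTransformerModel.__init__:
#     self.save_hyperparameters(params)  # records all three fields
# later, in training_step:
#     alpha = self.hparams.alpha         # -> 1.0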

makefile (2 changed lines)

@@ -4,4 +4,4 @@ lint:
 	flake8 --ignore E501 .

 test:
-	python main.py --alpha 0.7 --lr 1e-3 --max_epochs 1000
+	python main.py --alpha 1 --lr 1e-3 --max_epochs 100

model.py (18 changed lines)

@@ -3,17 +3,17 @@ import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-from losses import enhanced_loss, weighted_loss  # noqa: F401
+from losses import preservation_loss, smoothness_loss


 class ColorTransformerModel(pl.LightningModule):
-    def __init__(self, alpha, learning_rate):
+    def __init__(self, params):
         super().__init__()
-        self.save_hyperparameters()
+        self.save_hyperparameters(params)

         # Model layers
         self.layers = nn.Sequential(
-            nn.Linear(3, 128),
+            nn.Linear(5, 128),
             nn.ReLU(),
             nn.Linear(128, 128),
             nn.ReLU(),
@@ -22,7 +22,7 @@ class ColorTransformerModel(pl.LightningModule):
     def forward(self, x):
         x = self.layers(x)
-        x = torch.sigmoid(x)
+        x = (torch.sin(x) + 1) / 2
         return x


 # class ColorTransformerModel(pl.LightningModule):
@@ -68,12 +68,14 @@ class ColorTransformerModel(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         inputs, labels = batch  # x are the RGB inputs, labels are the strings
         outputs = self.forward(inputs)
-        # loss = weighted_loss(inputs, outputs, alpha=self.hparams.alpha)
-        loss = enhanced_loss(
+        s_loss = smoothness_loss(outputs)
+        p_loss = preservation_loss(
             inputs,
             outputs,
-            alpha=self.hparams.alpha,
         )
+        alpha = self.hparams.alpha
+        loss = p_loss + alpha * s_loss
+        self.log("hp_metric", p_loss)
         self.log("train_loss", loss)
         return loss
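Note: (torch.sin(x) + 1) / 2 maps activations into [0, 1] like the old sigmoid, but periodically and without saturating endpoints, which pairs with the wrap-around smoothness loss: the 1-D output behaves like a color wheel rather than a line segment. A quick comparison, for illustration only:

import torch

x = torch.linspace(-10.0, 10.0, steps=1000)
sig = torch.sigmoid(x)          # flattens toward 0 and 1 at the extremes
per = (torch.sin(x) + 1) / 2    # keeps cycling through [0, 1], period 2*pi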

scrape.py (13 changed lines)

@@ -1,6 +1,7 @@
 import glob
-from pathlib import Path
 import shutil
+from pathlib import Path

 from check import make_image
@@ -9,16 +10,18 @@ def get_exps(pattern: str, splitter: str = "_"):
     chkpt_basedir = "/work/colors/lightning_logs/"
     location = basedir + pattern
     res = glob.glob(location)
-    location = location.replace('*', '')
+    location = location.replace("*", "")
     H = []  # hyperparams used
     # print(res)
     for r in res:
-        d = r.replace(location, '').split(splitter)
+        d = r.replace(location, "").split(splitter)
         d = list(float(_d) for _d in d)
         d[0] = int(d[0])
         H.append(d)
     for i, r in enumerate(res):
-        dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/")
+        dir_path = Path(
+            f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/"
+        )
         dir_path.mkdir(parents=True, exist_ok=True)
         g = glob.glob(r + chkpt_basedir + "*")
         c = g[0] + "/checkpoints"
@@ -26,7 +29,7 @@ def get_exps(pattern: str, splitter: str = "_"):
     # print(latest_checkpoint)
     logs = glob.glob(g[0] + "/events*")[-1]
     print(logs)
-    source_path = Path(logs)
+    # source_path = Path(logs)
     # print("Would copy", source_path, dir_path)
     # shutil.copy(source_path, dir_path)
     make_image(latest_checkpoint, f"out/version_{i}")
