Browse Source

add in grayscale

new-sep-loss
mm 11 months ago
parent
commit
6e4f0f646d
  1. 14
      dataloader.py

14
dataloader.py

@ -16,7 +16,7 @@ def extract_colors():
return rgb_tensor, xkcd_color_names
def create_dataloader(**kwargs):
def create_dataloader(N: int = 50, **kwargs):
rgb_tensor, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
# Creating a dataset and data loader
@ -25,13 +25,21 @@ def create_dataloader(**kwargs):
return train_dataloader
def create_named_dataloader(**kwargs):
def create_gray_supplement(N: int = 50):
linear_space = torch.linspace(0, 1, N)
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3)
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))]
def create_named_dataloader(N: int = 50, **kwargs):
rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
# Creating a dataset with RGB values and their corresponding color names
dataset_with_names = [
(rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor))
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
for i in range(len(rgb_tensor))
]
dataset_with_names += create_gray_supplement(N)
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs)
return train_dataloader_with_names

Loading…
Cancel
Save