From 6e4f0f646d3c7ac0b3ffd4074a24e65571725078 Mon Sep 17 00:00:00 2001 From: mm Date: Sun, 31 Dec 2023 07:00:25 +0000 Subject: [PATCH] add in grayscale --- dataloader.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dataloader.py b/dataloader.py index 9fabf23..6ca33af 100644 --- a/dataloader.py +++ b/dataloader.py @@ -16,7 +16,7 @@ def extract_colors(): return rgb_tensor, xkcd_color_names -def create_dataloader(**kwargs): +def create_dataloader(N: int = 50, **kwargs): rgb_tensor, _ = extract_colors() rgb_tensor = preprocess_data(rgb_tensor) # Creating a dataset and data loader @@ -25,13 +25,21 @@ def create_dataloader(**kwargs): return train_dataloader -def create_named_dataloader(**kwargs): +def create_gray_supplement(N: int = 50): + linear_space = torch.linspace(0, 1, N) + gray_tensor = linear_space.unsqueeze(1).repeat(1, 3) + return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))] + + +def create_named_dataloader(N: int = 50, **kwargs): rgb_tensor, xkcd_color_names = extract_colors() rgb_tensor = preprocess_data(rgb_tensor) # Creating a dataset with RGB values and their corresponding color names dataset_with_names = [ - (rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor)) + (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) + for i in range(len(rgb_tensor)) ] + dataset_with_names += create_gray_supplement(N) train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs) return train_dataloader_with_names