|
|
@ -16,7 +16,7 @@ def extract_colors(): |
|
|
|
return rgb_tensor, xkcd_color_names |
|
|
|
|
|
|
|
|
|
|
|
def create_dataloader(**kwargs): |
|
|
|
def create_dataloader(N: int = 50, **kwargs): |
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
# Creating a dataset and data loader |
|
|
@ -25,13 +25,21 @@ def create_dataloader(**kwargs): |
|
|
|
return train_dataloader |
|
|
|
|
|
|
|
|
|
|
|
def create_named_dataloader(**kwargs): |
|
|
|
def create_gray_supplement(N: int = 50): |
|
|
|
linear_space = torch.linspace(0, 1, N) |
|
|
|
gray_tensor = linear_space.unsqueeze(1).repeat(1, 3) |
|
|
|
return [(gray_tensor[i], f"gray{i/N:2.4f}") for i in range(len(gray_tensor))] |
|
|
|
|
|
|
|
|
|
|
|
def create_named_dataloader(N: int = 50, **kwargs): |
|
|
|
rgb_tensor, xkcd_color_names = extract_colors() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
# Creating a dataset with RGB values and their corresponding color names |
|
|
|
dataset_with_names = [ |
|
|
|
(rgb_tensor[i], xkcd_color_names[i]) for i in range(len(rgb_tensor)) |
|
|
|
(rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) |
|
|
|
for i in range(len(rgb_tensor)) |
|
|
|
] |
|
|
|
dataset_with_names += create_gray_supplement(N) |
|
|
|
train_dataloader_with_names = DataLoader(dataset_with_names, **kwargs) |
|
|
|
return train_dataloader_with_names |
|
|
|
|
|
|
|