diff --git a/utils.py b/utils.py index b3be461..d5c70db 100644 --- a/utils.py +++ b/utils.py @@ -34,7 +34,7 @@ def extract_colors(): return rgb_tensor, xkcd_color_names -PURE_RGB = preprocess_data(torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]) +PURE_RGB = preprocess_data(torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)) PURE_HSV = torch.tensor( [[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32 )