diff --git a/check.py b/check.py index afefcda..693daf2 100644 --- a/check.py +++ b/check.py @@ -46,24 +46,27 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150): M = ckpt rgb_tensor, _ = extract_colors() + rgb_tensor = preprocess_data(rgb_tensor) preds = M(rgb_tensor.to(M.device)) - plot_preds(preds, fname=fname, dpi=dpi) - + plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, dpi=dpi) -def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)): - rgb_tensor, _ = extract_colors() - rgb_values = rgb_tensor.detach().numpy() - rgb_tensor = preprocess_data(rgb_tensor) +def plot_preds( + preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3) +): if isinstance(preds, torch.Tensor): preds = preds.detach().cpu().numpy() sorted_inds = np.argsort(preds.ravel()) - colors = rgb_values[sorted_inds] + colors = rgb_values[sorted_inds, :] if roll: # find white in colors, put it first. white = np.array([1, 1, 1]) - white_idx = np.where((colors == white).all(axis=1))[0][0] - colors = np.roll(colors, -white_idx, axis=0) + white_idx = np.where((colors == white).all(axis=1)) + if white_idx: + white_idx = white_idx[0][0] + colors = np.roll(colors, -white_idx, axis=0) + else: + print("no white, skipping") # print(white_idx, colors[:2]) N = len(colors) @@ -76,10 +79,13 @@ def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150, figsize=(3 for i in range(N): ax.bar( + # 2 * np.pi * preds[i], theta[i], - 1, + height=1, width=width, edgecolor="none", + # facecolor=[rgb_values[i][1]]*3, + # facecolor=rgb_values[i], facecolor=colors[i], bottom=0.0, zorder=1, diff --git a/hsv.png b/hsv.png index a23a00c..bc15ae4 100644 Binary files a/hsv.png and b/hsv.png differ diff --git a/hsv.py b/hsv.py index d453312..720ba79 100644 --- a/hsv.py +++ b/hsv.py @@ -7,8 +7,11 @@ from utils import extract_colors if __name__ == "__main__": rgb_tensor, _ = extract_colors() xkcd_rgb = rgb_tensor.numpy() + # xkcd_rgb = np.random.rand(1000, 3) xkcd_hsv = rgb_to_hsv(xkcd_rgb) - plot_preds(xkcd_hsv[:, 0], fname="hsv", roll=True, dpi=150, figsize=(6, 6)) + plot_preds( + xkcd_hsv[:, 0], xkcd_rgb, fname="hsv", roll=False, dpi=300, figsize=(6, 6) + ) rgb = np.eye(3) print("Pure RGB in Hue-Space:") print(rgb_to_hsv(rgb)[:, 0]) diff --git a/main.py b/main.py index 513b79f..1c31d5a 100644 --- a/main.py +++ b/main.py @@ -72,13 +72,13 @@ if __name__ == "__main__": save_img_callback = SaveImageCallback( save_interval=0, - final_dir="out", + final_dir=None, ) # Initialize data loader with parsed arguments # named_data_loader also has grayscale extras. TODO: remove unnamed train_dataloader = create_dataloader( - N=1e7, + N=1e5, batch_size=args.bs, shuffle=True, num_workers=args.num_workers, diff --git a/out/index.html b/out/index.html index 1f580c3..1209be9 100644 --- a/out/index.html +++ b/out/index.html @@ -74,7 +74,7 @@ for (var i = 0; i < 100; i++) { // Changed from i <= 100 to i < 100 let imageName; - if (i == 21) { + if (i == -21) { imageName = 'hsv.png'; } else { imageName = 'v' + i + '.png';