diff --git a/check.py b/check.py index 025bdde..a10464f 100644 --- a/check.py +++ b/check.py @@ -10,7 +10,7 @@ from model import ColorTransformerModel # import matplotlib.colors as mcolors -def make_image(ckpt: str, fname: str, color=True): +def make_image(ckpt: str, fname: str, color=True, **kwargs): M = ColorTransformerModel.load_from_checkpoint(ckpt) # preds = M(rgb_tensor) @@ -36,10 +36,10 @@ def make_image(ckpt: str, fname: str, color=True): ax.axis("off") # ax.axis("square") - plt.savefig(f"{fname}.png", dpi=300) + plt.savefig(f"{fname}.png", **kwargs) -def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True): +def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs): if isinstance(ckpt, str): M = ColorTransformerModel.load_from_checkpoint(ckpt) else: @@ -48,7 +48,7 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True): rgb_tensor, _ = extract_colors() rgb_tensor = preprocess_data(rgb_tensor) preds = M(rgb_tensor.to(M.device)) - plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, dpi=dpi) + plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs) def plot_preds( @@ -116,10 +116,11 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() # make the following accept a list of arguments - parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) + parser.add_argument("-v", "--version", type=int, nargs="+", default=[0]) parser.add_argument( "--dpi", type=int, default=150, help="Resolution for saved image." ) + parser.add_argument("--figsize", type=int, default=3, help="Figure size") args = parser.parse_args() versions = args.version for v in versions: @@ -130,7 +131,7 @@ if __name__ == "__main__": if len(ckpt) > 0: ckpt = ckpt[-1] print(f"Generating image for checkpoint: {ckpt}") - create_circle(ckpt, fname=name, dpi=args.dpi) + create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2) else: print(f"No checkpoint found for version {v}") - # make_image(ckpt, fname=name + "b", color=False) + # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,) diff --git a/out/index.html b/out/index.html index 1209be9..e9e18ad 100644 --- a/out/index.html +++ b/out/index.html @@ -77,7 +77,7 @@ if (i == -21) { imageName = 'hsv.png'; } else { - imageName = 'v' + i + '.png'; + imageName = 'v' + i + '.jpg'; } let img = document.createElement('img'); img.src = imageName; diff --git a/utils.py b/utils.py index d5c70db..3522a76 100644 --- a/utils.py +++ b/utils.py @@ -34,7 +34,9 @@ def extract_colors(): return rgb_tensor, xkcd_color_names -PURE_RGB = preprocess_data(torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)) +PURE_RGB = preprocess_data( + torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0) +) PURE_HSV = torch.tensor( [[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32 )