From 5ed305fe34f4888c510143cfd3dca53d17e56e67 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Mon, 15 Jan 2024 19:02:26 +0000 Subject: [PATCH] dpi argument --- check.py | 13 ++++++++----- main.py | 4 +--- makefile | 3 ++- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/check.py b/check.py index cc715ac..c3c452d 100644 --- a/check.py +++ b/check.py @@ -38,7 +38,7 @@ def make_image(ckpt: str, fname: str, color=True): plt.savefig(f"{fname}.png", dpi=300) -def create_circle(ckpt: str, fname: str): +def create_circle(ckpt: str, fname: str, dpi: int = 150): if isinstance(ckpt, str): M = ColorTransformerModel.load_from_checkpoint(ckpt) else: @@ -46,10 +46,10 @@ def create_circle(ckpt: str, fname: str): rgb_tensor, _ = extract_colors() preds = M(rgb_tensor.to(M.device)) - plot_preds(preds, fname=fname) + plot_preds(preds, fname=fname, dpi=dpi) -def plot_preds(preds, fname: str, roll: bool = False): +def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): rgb_tensor, _ = extract_colors() rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor) @@ -80,7 +80,7 @@ def plot_preds(preds, fname: str, roll: bool = False): ax.set_yticks([]) ax.axis("off") fig.tight_layout() - plt.savefig(f"{fname}.png", dpi=150) + plt.savefig(f"{fname}.png", dpi=dpi) plt.close() @@ -92,6 +92,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() # make the following accept a list of arguments parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) + parser.add_argument( + "--dpi", type=int, default=150, help="Resolution for saved image." + ) args = parser.parse_args() versions = args.version for v in versions: @@ -102,7 +105,7 @@ if __name__ == "__main__": if len(ckpt) > 0: ckpt = ckpt[-1] print(f"Generating image for checkpoint: {ckpt}") - create_circle(ckpt, fname=name) + create_circle(ckpt, fname=name, dpi=args.dpi) else: print(f"No checkpoint found for version {v}") # make_image(ckpt, fname=name + "b", color=False) diff --git a/main.py b/main.py index 84915d2..8fe580c 100644 --- a/main.py +++ b/main.py @@ -40,9 +40,7 @@ def parse_args(): default=3, help="Number of workers for data loading", ) - parser.add_argument( - "--width", type=int, default=128, help="Max width of network." - ) + parser.add_argument("--width", type=int, default=128, help="Max width of network.") # Parse arguments args = parser.parse_args() diff --git a/makefile b/makefile index 09b9dbe..8a1eb51 100644 --- a/makefile +++ b/makefile @@ -16,6 +16,7 @@ animate: ~/animated.mp4 clean: - rm -rf lightning_logs/* + rm -rf lightning_logs rm -f out/*.png + rm -rf __pycache__/ cp hsv.png out/