Browse Source

dpi argument

new-sep-loss
Michael Pilosov 10 months ago
parent
commit
5ed305fe34
  1. 13
      check.py
  2. 4
      main.py
  3. 3
      makefile

13
check.py

@ -38,7 +38,7 @@ def make_image(ckpt: str, fname: str, color=True):
plt.savefig(f"{fname}.png", dpi=300) plt.savefig(f"{fname}.png", dpi=300)
def create_circle(ckpt: str, fname: str): def create_circle(ckpt: str, fname: str, dpi: int = 150):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
else: else:
@ -46,10 +46,10 @@ def create_circle(ckpt: str, fname: str):
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
preds = M(rgb_tensor.to(M.device)) preds = M(rgb_tensor.to(M.device))
plot_preds(preds, fname=fname) plot_preds(preds, fname=fname, dpi=dpi)
def plot_preds(preds, fname: str, roll: bool = False): def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150):
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
rgb_values = rgb_tensor.detach().numpy() rgb_values = rgb_tensor.detach().numpy()
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor)
@ -80,7 +80,7 @@ def plot_preds(preds, fname: str, roll: bool = False):
ax.set_yticks([]) ax.set_yticks([])
ax.axis("off") ax.axis("off")
fig.tight_layout() fig.tight_layout()
plt.savefig(f"{fname}.png", dpi=150) plt.savefig(f"{fname}.png", dpi=dpi)
plt.close() plt.close()
@ -92,6 +92,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# make the following accept a list of arguments # make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1])
parser.add_argument(
"--dpi", type=int, default=150, help="Resolution for saved image."
)
args = parser.parse_args() args = parser.parse_args()
versions = args.version versions = args.version
for v in versions: for v in versions:
@ -102,7 +105,7 @@ if __name__ == "__main__":
if len(ckpt) > 0: if len(ckpt) > 0:
ckpt = ckpt[-1] ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}") print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name) create_circle(ckpt, fname=name, dpi=args.dpi)
else: else:
print(f"No checkpoint found for version {v}") print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False) # make_image(ckpt, fname=name + "b", color=False)

4
main.py

@ -40,9 +40,7 @@ def parse_args():
default=3, default=3,
help="Number of workers for data loading", help="Number of workers for data loading",
) )
parser.add_argument( parser.add_argument("--width", type=int, default=128, help="Max width of network.")
"--width", type=int, default=128, help="Max width of network."
)
# Parse arguments # Parse arguments
args = parser.parse_args() args = parser.parse_args()

3
makefile

@ -16,6 +16,7 @@ animate:
~/animated.mp4 ~/animated.mp4
clean: clean:
rm -rf lightning_logs/* rm -rf lightning_logs
rm -f out/*.png rm -f out/*.png
rm -rf __pycache__/
cp hsv.png out/ cp hsv.png out/

Loading…
Cancel
Save