|
@ -38,7 +38,7 @@ def make_image(ckpt: str, fname: str, color=True): |
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_circle(ckpt: str, fname: str): |
|
|
def create_circle(ckpt: str, fname: str, dpi: int = 150): |
|
|
if isinstance(ckpt, str): |
|
|
if isinstance(ckpt, str): |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
else: |
|
|
else: |
|
@ -46,10 +46,10 @@ def create_circle(ckpt: str, fname: str): |
|
|
|
|
|
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_tensor, _ = extract_colors() |
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
plot_preds(preds, fname=fname) |
|
|
plot_preds(preds, fname=fname, dpi=dpi) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_preds(preds, fname: str, roll: bool = False): |
|
|
def plot_preds(preds, fname: str, roll: bool = False, dpi: int = 150): |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_tensor, _ = extract_colors() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
@ -80,7 +80,7 @@ def plot_preds(preds, fname: str, roll: bool = False): |
|
|
ax.set_yticks([]) |
|
|
ax.set_yticks([]) |
|
|
ax.axis("off") |
|
|
ax.axis("off") |
|
|
fig.tight_layout() |
|
|
fig.tight_layout() |
|
|
plt.savefig(f"{fname}.png", dpi=150) |
|
|
plt.savefig(f"{fname}.png", dpi=dpi) |
|
|
plt.close() |
|
|
plt.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -92,6 +92,9 @@ if __name__ == "__main__": |
|
|
parser = argparse.ArgumentParser() |
|
|
parser = argparse.ArgumentParser() |
|
|
# make the following accept a list of arguments |
|
|
# make the following accept a list of arguments |
|
|
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) |
|
|
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) |
|
|
|
|
|
parser.add_argument( |
|
|
|
|
|
"--dpi", type=int, default=150, help="Resolution for saved image." |
|
|
|
|
|
) |
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
versions = args.version |
|
|
versions = args.version |
|
|
for v in versions: |
|
|
for v in versions: |
|
@ -102,7 +105,7 @@ if __name__ == "__main__": |
|
|
if len(ckpt) > 0: |
|
|
if len(ckpt) > 0: |
|
|
ckpt = ckpt[-1] |
|
|
ckpt = ckpt[-1] |
|
|
print(f"Generating image for checkpoint: {ckpt}") |
|
|
print(f"Generating image for checkpoint: {ckpt}") |
|
|
create_circle(ckpt, fname=name) |
|
|
create_circle(ckpt, fname=name, dpi=args.dpi) |
|
|
else: |
|
|
else: |
|
|
print(f"No checkpoint found for version {v}") |
|
|
print(f"No checkpoint found for version {v}") |
|
|
# make_image(ckpt, fname=name + "b", color=False) |
|
|
# make_image(ckpt, fname=name + "b", color=False) |
|
|