|
|
@ -1,6 +1,6 @@ |
|
|
|
# import matplotlib.patches as patches |
|
|
|
from pathlib import Path |
|
|
|
from typing import Union, Tuple |
|
|
|
from typing import Tuple, Union |
|
|
|
|
|
|
|
import matplotlib.patches as patches |
|
|
|
import matplotlib.pyplot as plt |
|
|
@ -152,16 +152,19 @@ if __name__ == "__main__": |
|
|
|
parser.add_argument( |
|
|
|
"--dpi", type=int, default=300, help="Resolution for saved image." |
|
|
|
) |
|
|
|
parser.add_argument( |
|
|
|
"--studio", |
|
|
|
type=str, |
|
|
|
default="this_studio", |
|
|
|
nargs="+", |
|
|
|
help="Checkpoint studio name.", |
|
|
|
) |
|
|
|
parser.add_argument("--figsize", type=int, default=6, help="Figure size") |
|
|
|
args = parser.parse_args() |
|
|
|
versions = args.version |
|
|
|
for v in versions: |
|
|
|
# name = f"out/v{v}" |
|
|
|
studio = "colors-refactor-supervised" |
|
|
|
# studio = "colors-refactor-unsupervised" |
|
|
|
# studio = "colors-refactor-unsupervised-anchors" |
|
|
|
# studio = "this_studio" |
|
|
|
for studio in args.studio: |
|
|
|
Path(studio).mkdir(exist_ok=True, parents=True) |
|
|
|
for v in versions: |
|
|
|
name = f"{studio}/v{v}" |
|
|
|
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
|
|
|
# ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" |
|
|
@ -171,7 +174,11 @@ if __name__ == "__main__": |
|
|
|
ckpt = ckpt[-1] |
|
|
|
print(f"Generating image for checkpoint: {ckpt}") |
|
|
|
create_circle( |
|
|
|
ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False |
|
|
|
ckpt, |
|
|
|
fname=name, |
|
|
|
dpi=args.dpi, |
|
|
|
figsize=[args.figsize] * 2, |
|
|
|
roll=False, |
|
|
|
) |
|
|
|
else: |
|
|
|
print(f"No checkpoint found for version {v}") |
|
|
|