diff --git a/baseline.py b/baseline.py index 216808b..df2bf77 100644 --- a/baseline.py +++ b/baseline.py @@ -2,7 +2,6 @@ import argparse from pathlib import Path import matplotlib.colors as mcolors -import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np from hilbertcurve.hilbertcurve import HilbertCurve @@ -35,7 +34,7 @@ DPI = args.dpi SIZE = args.size FONTSIZE = args.fontsize INNER_RADIUS = args.radius -DIR = "/teamspace/studios/colors/umap" +DIR = "/teamspace/studios/this_studio/colors/colors-umap" prefix = "" @@ -46,7 +45,6 @@ Path(FDIR).mkdir(exist_ok=True, parents=True) fname = f"{FDIR}/{prefix}sorted_colors_circle.png" - if KIND in ("lex", "alpha", "abc"): preds = np.array(colors) diff --git a/check.py b/check.py index 6160e39..cefe554 100644 --- a/check.py +++ b/check.py @@ -1,6 +1,6 @@ # import matplotlib.patches as patches from pathlib import Path -from typing import Union, Tuple +from typing import Tuple, Union import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -152,27 +152,34 @@ if __name__ == "__main__": parser.add_argument( "--dpi", type=int, default=300, help="Resolution for saved image." ) + parser.add_argument( + "--studio", + type=str, + default="this_studio", + nargs="+", + help="Checkpoint studio name.", + ) parser.add_argument("--figsize", type=int, default=6, help="Figure size") args = parser.parse_args() versions = args.version - for v in versions: - # name = f"out/v{v}" - studio = "colors-refactor-supervised" - # studio = "colors-refactor-unsupervised" - # studio = "colors-refactor-unsupervised-anchors" - # studio = "this_studio" + for studio in args.studio: Path(studio).mkdir(exist_ok=True, parents=True) - name = f"{studio}/v{v}" - # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" - # ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" - ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" - ckpt = glob.glob(ckpt_path) - if len(ckpt) > 0: - ckpt = ckpt[-1] - print(f"Generating image for checkpoint: {ckpt}") - create_circle( - ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False - ) - else: - print(f"No checkpoint found for version {v}") - # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,) + for v in versions: + name = f"{studio}/v{v}" + # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" + # ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" + ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" + ckpt = glob.glob(ckpt_path) + if len(ckpt) > 0: + ckpt = ckpt[-1] + print(f"Generating image for checkpoint: {ckpt}") + create_circle( + ckpt, + fname=name, + dpi=args.dpi, + figsize=[args.figsize] * 2, + roll=False, + ) + else: + print(f"No checkpoint found for version {v}") + # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,) diff --git a/makefile b/makefile index 6c9e806..bd486b5 100644 --- a/makefile +++ b/makefile @@ -62,7 +62,9 @@ parallel_umap: parallel -j 4 python baseline.py -s umap --dpi 300 --seed ::: $$(seq 1 100) parallel_check: - parallel -j 4 python check.py -v ::: $$(seq 0 99) + parallel -j 3 python check.py \ + --studio colors-refactor-unsupervised colors-refactor-supervised colors-refactor-unsupervised-anchors \ + -v ::: $$(seq 0 99) sort_lex: python scripts/sortcolor.py -s lex --dpi 300