diff --git a/.gitignore b/.gitignore index 4b2cf13..29f5c53 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ lightning_logs/ __pycache__/ +out/ .sw[opqr] diff --git a/check.py b/check.py index ddf855b..60f5b71 100644 --- a/check.py +++ b/check.py @@ -1,25 +1,38 @@ import matplotlib.pyplot as plt import numpy as np +import torch from dataloader import extract_colors from model import ColorTransformerModel -name = "color_128_0.3_1.00e-06" -ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" -M = ColorTransformerModel.load_from_checkpoint(ckpt) +def make_image(ckpt: str, fname: str, color=True): + M = ColorTransformerModel.load_from_checkpoint(ckpt) -rgb_tensor, names = extract_colors() -preds = M(rgb_tensor) -rgb_values = rgb_tensor.detach().numpy() -sorted_inds = np.argsort(preds.detach().numpy().ravel()) + # preds = M(rgb_tensor) + if not color: + N = 949 + linear_space = torch.linspace(0, 1, N) + rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3) + else: + rgb_tensor, names = extract_colors() + rgb_values = rgb_tensor.detach().numpy() + preds = M(rgb_tensor) + sorted_inds = np.argsort(preds.detach().numpy().ravel()) -fig, ax = plt.subplots(figsize=(10, 5)) -for i in range(len(sorted_inds)): - idx = sorted_inds[i] - color = rgb_values[idx] - ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx]) - ax.axis("off") - # ax.axis("square") + fig, ax = plt.subplots(figsize=(10, 5)) + for i in range(len(sorted_inds)): + idx = sorted_inds[i] + color = rgb_values[idx] + ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5) + ax.axis("off") + # ax.axis("square") -plt.savefig(f"{name}.png", dpi=300) + plt.savefig(f"{fname}.png", dpi=300) + + +if __name__ == "__main__": + + name = "color_128_0.3_1.00e-06" + ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" + make_image(ckpt, fname=name) diff --git a/scrape.py b/scrape.py new file mode 100644 index 0000000..1da8248 --- /dev/null +++ b/scrape.py @@ -0,0 +1,40 @@ +import glob +from pathlib import Path +import shutil +from check import make_image + + +def get_exps(pattern: str, splitter: str = "_"): + basedir = "/teamspace/jobs/" + chkpt_basedir = "/work/colors/lightning_logs/" + location = basedir + pattern + res = glob.glob(location) + location = location.replace('*', '') + H = [] # hyperparams used + # print(res) + for r in res: + d = r.replace(location, '').split(splitter) + d = list(float(_d) for _d in d) + d[0] = int(d[0]) + H.append(d) + for i, r in enumerate(res): + dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/") + dir_path.mkdir(parents=True, exist_ok=True) + g = glob.glob(r + chkpt_basedir + "*") + c = g[0] + "/checkpoints" + latest_checkpoint = glob.glob(c + "/*")[-1] + # print(latest_checkpoint) + logs = glob.glob(g[0] + "/events*")[-1] + print(logs) + source_path = Path(logs) + # print("Would copy", source_path, dir_path) + # shutil.copy(source_path, dir_path) + make_image(latest_checkpoint, f"out/version_{i}") + make_image(latest_checkpoint, f"out/version_{i}b", color=False) + + return H + + +if __name__ == "__main__": + D = get_exps("color_*", "_") + print(len(D), D)