mm
11 months ago
3 changed files with 69 additions and 15 deletions
@ -1,3 +1,4 @@ |
|||
lightning_logs/ |
|||
__pycache__/ |
|||
out/ |
|||
.sw[opqr] |
|||
|
@ -1,25 +1,38 @@ |
|||
import matplotlib.pyplot as plt |
|||
import numpy as np |
|||
import torch |
|||
|
|||
from dataloader import extract_colors |
|||
from model import ColorTransformerModel |
|||
|
|||
name = "color_128_0.3_1.00e-06" |
|||
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
|||
def make_image(ckpt: str, fname: str, color=True): |
|||
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|||
|
|||
# preds = M(rgb_tensor) |
|||
if not color: |
|||
N = 949 |
|||
linear_space = torch.linspace(0, 1, N) |
|||
rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3) |
|||
else: |
|||
rgb_tensor, names = extract_colors() |
|||
preds = M(rgb_tensor) |
|||
|
|||
rgb_values = rgb_tensor.detach().numpy() |
|||
preds = M(rgb_tensor) |
|||
sorted_inds = np.argsort(preds.detach().numpy().ravel()) |
|||
|
|||
|
|||
fig, ax = plt.subplots(figsize=(10, 5)) |
|||
for i in range(len(sorted_inds)): |
|||
idx = sorted_inds[i] |
|||
color = rgb_values[idx] |
|||
ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx]) |
|||
ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5) |
|||
ax.axis("off") |
|||
# ax.axis("square") |
|||
|
|||
plt.savefig(f"{name}.png", dpi=300) |
|||
plt.savefig(f"{fname}.png", dpi=300) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
|
|||
name = "color_128_0.3_1.00e-06" |
|||
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
|||
make_image(ckpt, fname=name) |
|||
|
@ -0,0 +1,40 @@ |
|||
import glob |
|||
from pathlib import Path |
|||
import shutil |
|||
from check import make_image |
|||
|
|||
|
|||
def get_exps(pattern: str, splitter: str = "_"): |
|||
basedir = "/teamspace/jobs/" |
|||
chkpt_basedir = "/work/colors/lightning_logs/" |
|||
location = basedir + pattern |
|||
res = glob.glob(location) |
|||
location = location.replace('*', '') |
|||
H = [] # hyperparams used |
|||
# print(res) |
|||
for r in res: |
|||
d = r.replace(location, '').split(splitter) |
|||
d = list(float(_d) for _d in d) |
|||
d[0] = int(d[0]) |
|||
H.append(d) |
|||
for i, r in enumerate(res): |
|||
dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/") |
|||
dir_path.mkdir(parents=True, exist_ok=True) |
|||
g = glob.glob(r + chkpt_basedir + "*") |
|||
c = g[0] + "/checkpoints" |
|||
latest_checkpoint = glob.glob(c + "/*")[-1] |
|||
# print(latest_checkpoint) |
|||
logs = glob.glob(g[0] + "/events*")[-1] |
|||
print(logs) |
|||
source_path = Path(logs) |
|||
# print("Would copy", source_path, dir_path) |
|||
# shutil.copy(source_path, dir_path) |
|||
make_image(latest_checkpoint, f"out/version_{i}") |
|||
make_image(latest_checkpoint, f"out/version_{i}b", color=False) |
|||
|
|||
return H |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
D = get_exps("color_*", "_") |
|||
print(len(D), D) |
Loading…
Reference in new issue