Browse Source

scraping code, generate images

new-sep-loss
mm 11 months ago
parent
commit
1e63991960
  1. 1
      .gitignore
  2. 43
      check.py
  3. 40
      scrape.py

1
.gitignore

@ -1,3 +1,4 @@
lightning_logs/
__pycache__/
out/
.sw[opqr]

43
check.py

@ -1,25 +1,38 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
from dataloader import extract_colors
from model import ColorTransformerModel
name = "color_128_0.3_1.00e-06"
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
M = ColorTransformerModel.load_from_checkpoint(ckpt)
def make_image(ckpt: str, fname: str, color=True):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
rgb_tensor, names = extract_colors()
preds = M(rgb_tensor)
rgb_values = rgb_tensor.detach().numpy()
sorted_inds = np.argsort(preds.detach().numpy().ravel())
# preds = M(rgb_tensor)
if not color:
N = 949
linear_space = torch.linspace(0, 1, N)
rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3)
else:
rgb_tensor, names = extract_colors()
rgb_values = rgb_tensor.detach().numpy()
preds = M(rgb_tensor)
sorted_inds = np.argsort(preds.detach().numpy().ravel())
fig, ax = plt.subplots(figsize=(10, 5))
for i in range(len(sorted_inds)):
idx = sorted_inds[i]
color = rgb_values[idx]
ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx])
ax.axis("off")
# ax.axis("square")
fig, ax = plt.subplots(figsize=(10, 5))
for i in range(len(sorted_inds)):
idx = sorted_inds[i]
color = rgb_values[idx]
ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5)
ax.axis("off")
# ax.axis("square")
plt.savefig(f"{name}.png", dpi=300)
plt.savefig(f"{fname}.png", dpi=300)
if __name__ == "__main__":
name = "color_128_0.3_1.00e-06"
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
make_image(ckpt, fname=name)

40
scrape.py

@ -0,0 +1,40 @@
import glob
from pathlib import Path
import shutil
from check import make_image
def get_exps(pattern: str, splitter: str = "_"):
basedir = "/teamspace/jobs/"
chkpt_basedir = "/work/colors/lightning_logs/"
location = basedir + pattern
res = glob.glob(location)
location = location.replace('*', '')
H = [] # hyperparams used
# print(res)
for r in res:
d = r.replace(location, '').split(splitter)
d = list(float(_d) for _d in d)
d[0] = int(d[0])
H.append(d)
for i, r in enumerate(res):
dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/")
dir_path.mkdir(parents=True, exist_ok=True)
g = glob.glob(r + chkpt_basedir + "*")
c = g[0] + "/checkpoints"
latest_checkpoint = glob.glob(c + "/*")[-1]
# print(latest_checkpoint)
logs = glob.glob(g[0] + "/events*")[-1]
print(logs)
source_path = Path(logs)
# print("Would copy", source_path, dir_path)
# shutil.copy(source_path, dir_path)
make_image(latest_checkpoint, f"out/version_{i}")
make_image(latest_checkpoint, f"out/version_{i}b", color=False)
return H
if __name__ == "__main__":
D = get_exps("color_*", "_")
print(len(D), D)
Loading…
Cancel
Save