mm
11 months ago
3 changed files with 69 additions and 15 deletions
@ -1,3 +1,4 @@ |
|||||
lightning_logs/ |
lightning_logs/ |
||||
__pycache__/ |
__pycache__/ |
||||
|
out/ |
||||
.sw[opqr] |
.sw[opqr] |
||||
|
@ -1,25 +1,38 @@ |
|||||
import matplotlib.pyplot as plt |
import matplotlib.pyplot as plt |
||||
import numpy as np |
import numpy as np |
||||
|
import torch |
||||
|
|
||||
from dataloader import extract_colors |
from dataloader import extract_colors |
||||
from model import ColorTransformerModel |
from model import ColorTransformerModel |
||||
|
|
||||
name = "color_128_0.3_1.00e-06" |
def make_image(ckpt: str, fname: str, color=True): |
||||
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
||||
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
||||
|
|
||||
rgb_tensor, names = extract_colors() |
# preds = M(rgb_tensor) |
||||
preds = M(rgb_tensor) |
if not color: |
||||
rgb_values = rgb_tensor.detach().numpy() |
N = 949 |
||||
sorted_inds = np.argsort(preds.detach().numpy().ravel()) |
linear_space = torch.linspace(0, 1, N) |
||||
|
rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3) |
||||
|
else: |
||||
|
rgb_tensor, names = extract_colors() |
||||
|
|
||||
|
rgb_values = rgb_tensor.detach().numpy() |
||||
|
preds = M(rgb_tensor) |
||||
|
sorted_inds = np.argsort(preds.detach().numpy().ravel()) |
||||
|
|
||||
fig, ax = plt.subplots(figsize=(10, 5)) |
fig, ax = plt.subplots(figsize=(10, 5)) |
||||
for i in range(len(sorted_inds)): |
for i in range(len(sorted_inds)): |
||||
idx = sorted_inds[i] |
idx = sorted_inds[i] |
||||
color = rgb_values[idx] |
color = rgb_values[idx] |
||||
ax.vlines(4 * i, ymin=0, ymax=1, lw=1, colors=names[idx]) |
ax.vlines(i, ymin=0, ymax=1, lw=0.1, colors=color, antialiased=False, alpha=0.5) |
||||
ax.axis("off") |
ax.axis("off") |
||||
# ax.axis("square") |
# ax.axis("square") |
||||
|
|
||||
plt.savefig(f"{name}.png", dpi=300) |
plt.savefig(f"{fname}.png", dpi=300) |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
|
||||
|
name = "color_128_0.3_1.00e-06" |
||||
|
ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" |
||||
|
make_image(ckpt, fname=name) |
||||
|
@ -0,0 +1,40 @@ |
|||||
|
import glob |
||||
|
from pathlib import Path |
||||
|
import shutil |
||||
|
from check import make_image |
||||
|
|
||||
|
|
||||
|
def get_exps(pattern: str, splitter: str = "_"): |
||||
|
basedir = "/teamspace/jobs/" |
||||
|
chkpt_basedir = "/work/colors/lightning_logs/" |
||||
|
location = basedir + pattern |
||||
|
res = glob.glob(location) |
||||
|
location = location.replace('*', '') |
||||
|
H = [] # hyperparams used |
||||
|
# print(res) |
||||
|
for r in res: |
||||
|
d = r.replace(location, '').split(splitter) |
||||
|
d = list(float(_d) for _d in d) |
||||
|
d[0] = int(d[0]) |
||||
|
H.append(d) |
||||
|
for i, r in enumerate(res): |
||||
|
dir_path = Path(f"/teamspace/studios/this_studio/colors/lightning_logs/version_{i}/") |
||||
|
dir_path.mkdir(parents=True, exist_ok=True) |
||||
|
g = glob.glob(r + chkpt_basedir + "*") |
||||
|
c = g[0] + "/checkpoints" |
||||
|
latest_checkpoint = glob.glob(c + "/*")[-1] |
||||
|
# print(latest_checkpoint) |
||||
|
logs = glob.glob(g[0] + "/events*")[-1] |
||||
|
print(logs) |
||||
|
source_path = Path(logs) |
||||
|
# print("Would copy", source_path, dir_path) |
||||
|
# shutil.copy(source_path, dir_path) |
||||
|
make_image(latest_checkpoint, f"out/version_{i}") |
||||
|
make_image(latest_checkpoint, f"out/version_{i}b", color=False) |
||||
|
|
||||
|
return H |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
D = get_exps("color_*", "_") |
||||
|
print(len(D), D) |
Loading…
Reference in new issue