diff --git a/.gitignore b/.gitignore index 651f70e..fdc6906 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ out/ *.png .sw[opqr] *.tar.gz +.pat diff --git a/check.py b/check.py index 7f13648..bd38e4e 100644 --- a/check.py +++ b/check.py @@ -6,9 +6,6 @@ import torch from dataloader import extract_colors, preprocess_data from model import ColorTransformerModel -import numpy as np -import matplotlib.pyplot as plt - # import matplotlib.colors as mcolors diff --git a/makefile b/makefile index 950c38b..3c7e957 100644 --- a/makefile +++ b/makefile @@ -1,7 +1,7 @@ lint: black . isort --profile=black . - flake8 --ignore E501 . + flake8 --ignore E501,W503 . test: python main.py --alpha 1 --lr 1e-4 --max_epochs 500 diff --git a/scrape.py b/scrape.py index f19a843..7d7cb1c 100644 --- a/scrape.py +++ b/scrape.py @@ -5,7 +5,7 @@ from pathlib import Path from check import make_image -def get_exps(pattern: str, splitter: str = "_"): +def get_exps(pattern: str, splitter: str = "_", dry_run: bool = True): basedir = "/teamspace/jobs/" chkpt_basedir = "/work/colors/lightning_logs/" location = basedir + pattern @@ -24,17 +24,19 @@ def get_exps(pattern: str, splitter: str = "_"): ) dir_path.mkdir(parents=True, exist_ok=True) g = glob.glob(r + chkpt_basedir + "*") - c = g[0] + "/checkpoints" - latest_checkpoint = glob.glob(c + "/*")[-1] - # print(latest_checkpoint) logs = glob.glob(g[0] + "/events*")[-1] - print(logs) source_path = Path(logs) - print("Would copy", source_path, dir_path) - # shutil.copy(source_path, dir_path) - # make_image(latest_checkpoint, f"out/version_{i}") - # make_image(latest_checkpoint, f"out/version_{i}b", color=False) - + print(logs) + if not dry_run: + c = g[0] + "/checkpoints" + latest_checkpoint = glob.glob(c + "/*")[-1] + print(latest_checkpoint) + if not dry_run: + shutil.copy(source_path, dir_path) + make_image(latest_checkpoint, f"out/version_{i}") + # make_image(latest_checkpoint, f"out/version_{i}b", color=False) + else: + print("Would copy", source_path, dir_path) return H diff --git a/search.py b/search.py index f73f8fc..7a9b3cc 100644 --- a/search.py +++ b/search.py @@ -3,7 +3,7 @@ import sys from random import sample import numpy as np -from lightning_sdk import Machine, Studio +from lightning_sdk import Machine, Studio # noqa: F401 NUM_JOBS = 64