Browse Source

finalize plotting (6in, no roll)

main
Michael Pilosov, PhD 9 months ago
parent
commit
c909ad4372
  1. 21
      check.py
  2. 2
      makefile
  3. 197
      requirements-new.txt

21
check.py

@ -1,5 +1,6 @@
# import matplotlib.patches as patches
from typing import Union
from pathlib import Path
import matplotlib.patches as patches
import matplotlib.pyplot as plt
@ -45,7 +46,8 @@ def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
import yaml
M = ColorTransformerModel.load_from_checkpoint(ckpt, map_location=lambda storage, loc: storage)
else:
M = ckpt
@ -105,7 +107,6 @@ def plot_preds(
ax.axis("off")
radius = 1
ax.set_ylim(-radius, radius)
ax.set_xlim(-radius, radius)
# Overlay white circle
inner_radius = 1 / 3
@ -131,20 +132,26 @@ if __name__ == "__main__":
# make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
parser.add_argument(
"--dpi", type=int, default=150, help="Resolution for saved image."
"--dpi", type=int, default=300, help="Resolution for saved image."
)
parser.add_argument("--figsize", type=int, default=3, help="Figure size")
parser.add_argument("--figsize", type=int, default=6, help="Figure size")
args = parser.parse_args()
versions = args.version
for v in versions:
name = f"out/v{v}"
# name = f"out/v{v}"
studio = "colors-refactor-supervised"
# studio = "colors-refactor-unsupervised"
# studio = "colors-refactor-unsupervised-anchors"
Path(studio).mkdir(exist_ok=True, parents=True)
name = f"{studio}/v{v}"
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
ckpt_path = f"/teamspace/studios/colors-refactor-secondary/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
# ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt = glob.glob(ckpt_path)
if len(ckpt) > 0:
ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2)
create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False)
else:
print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)

2
makefile

@ -59,7 +59,7 @@ sort_umap:
python scripts/sortcolor.py -s umap --dpi 300 --seed 21
parallel_umap:
parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 1000)
parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
sort_lex:
python scripts/sortcolor.py -s lex --dpi 300

197
requirements-new.txt

@ -0,0 +1,197 @@
absl-py==2.1.0
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
backoff==2.2.1
beautifulsoup4==4.12.3
bleach==6.1.0
boto3==1.34.42
botocore==1.34.42
cachetools==5.3.2
certifi==2024.2.2
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001684923/work
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.1
contourpy==1.2.0
cuda-python==12.3.0
cudf-cu12==23.12.1
cuml-cu12==23.12.0
cupy-cuda12x==13.0.0
cycler==0.12.1
dask==2023.11.0
dask-cuda==23.12.0
dask-cudf-cu12==23.12.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
distributed==2023.11.0
exceptiongroup==1.2.0
executing==2.0.1
fastapi==0.109.2
fastjsonschema==2.19.1
fastrlock==0.8.2
filelock==3.13.1
fire==0.5.0
fonttools==4.48.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.2.0
google-auth==2.27.0
google-auth-oauthlib==1.2.0
grpcio==1.60.1
h11==0.14.0
idna==3.6
importlib-metadata==7.0.1
ipykernel==6.26.0
ipython==8.17.2
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.2
jupyter_client==8.6.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.2
jupyterlab==4.0.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.3
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
lightning==2.2.0.post0
lightning-cloud==0.5.64
lightning-utilities==0.10.1
lightning_sdk==0.0.18a0
llvmlite==0.40.1
locket==1.0.0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.2
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.5
nbclient==0.9.0
nbconvert==7.16.0
nbformat==5.9.2
nest-asyncio==1.6.0
networkx==3.2.1
notebook_shim==0.2.4
numba==0.57.1
numpy==1.24.4
nvtx==0.2.10
oauthlib==3.2.2
overrides==7.7.0
packaging==23.2
pandas==1.5.3
pandocfilters==1.5.1
parso==0.8.3
partd==1.4.1
pexpect==4.9.0
pillow==10.2.0
pkgconfig @ file:///home/conda/feedstock_root/build_artifacts/pkgconfig_1667031109701/work
platformdirs==4.2.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
protobuf==4.23.4
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.17.2
PyJWT==2.8.0
pylibraft-cu12==23.12.0
pynvml==11.4.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
python-multipart==0.0.9
pytorch-lightning==2.2.0
pytz==2024.1
pyvips @ file:///home/conda/feedstock_root/build_artifacts/pyvips_1695697837675/work
PyYAML==6.0.1
pyzmq==25.1.2
raft-dask-cu12==23.12.0
rapids-dask-dependency==23.12.1
referencing==0.33.0
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rmm-cu12==23.12.0
rpds-py==0.18.0
rsa==4.9
s3transfer==0.10.0
scikit-learn==1.3.2
scipy==1.11.4
Send2Trash==1.8.2
simple-term-menu==1.6.4
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.36.3
sympy==1.12
tblib==3.0.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado==0.18.0
threadpoolctl==3.3.0
tinycss2==1.2.1
tomli==2.0.1
toolz==0.12.1
torch==2.1.2+cu121
torchaudio==2.1.2+cu121
torchmetrics==1.2.0
torchvision==0.16.2+cu121
tornado==6.4
tqdm==4.66.2
traitlets==5.14.1
treelite==3.9.1
treelite-runtime==3.9.1
triton==2.1.0
types-python-dateutil==2.8.19.20240106
typing_extensions==4.9.0
tzdata==2024.1
ucx-py-cu12==0.35.0
uri-template==1.3.0
urllib3==2.0.7
uvicorn==0.27.1
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.1
widgetsnbextension==4.0.10
yarl==1.9.4
zict==3.0.0
zipp==3.17.0
Loading…
Cancel
Save