diff --git a/check.py b/check.py index 56b458f..8fc616c 100644 --- a/check.py +++ b/check.py @@ -1,5 +1,6 @@ # import matplotlib.patches as patches from typing import Union +from pathlib import Path import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -45,7 +46,8 @@ def create_circle( ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs ): if isinstance(ckpt, str): - M = ColorTransformerModel.load_from_checkpoint(ckpt) + import yaml + M = ColorTransformerModel.load_from_checkpoint(ckpt, map_location=lambda storage, loc: storage) else: M = ckpt @@ -105,7 +107,6 @@ def plot_preds( ax.axis("off") radius = 1 ax.set_ylim(-radius, radius) - ax.set_xlim(-radius, radius) # Overlay white circle inner_radius = 1 / 3 @@ -131,20 +132,26 @@ if __name__ == "__main__": # make the following accept a list of arguments parser.add_argument("-v", "--version", type=int, nargs="+", default=[0]) parser.add_argument( - "--dpi", type=int, default=150, help="Resolution for saved image." + "--dpi", type=int, default=300, help="Resolution for saved image." ) - parser.add_argument("--figsize", type=int, default=3, help="Figure size") + parser.add_argument("--figsize", type=int, default=6, help="Figure size") args = parser.parse_args() versions = args.version for v in versions: - name = f"out/v{v}" + # name = f"out/v{v}" + studio = "colors-refactor-supervised" + # studio = "colors-refactor-unsupervised" + # studio = "colors-refactor-unsupervised-anchors" + Path(studio).mkdir(exist_ok=True, parents=True) + name = f"{studio}/v{v}" # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" - ckpt_path = f"/teamspace/studios/colors-refactor-secondary/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" + # ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" + ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" ckpt = glob.glob(ckpt_path) if len(ckpt) > 0: ckpt = ckpt[-1] print(f"Generating image for checkpoint: {ckpt}") - create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2) + create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False) else: print(f"No checkpoint found for version {v}") # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,) diff --git a/makefile b/makefile index 9228459..5751b77 100644 --- a/makefile +++ b/makefile @@ -59,7 +59,7 @@ sort_umap: python scripts/sortcolor.py -s umap --dpi 300 --seed 21 parallel_umap: - parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 1000) + parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100) sort_lex: python scripts/sortcolor.py -s lex --dpi 300 diff --git a/requirements-new.txt b/requirements-new.txt new file mode 100644 index 0000000..0c0fe66 --- /dev/null +++ b/requirements-new.txt @@ -0,0 +1,197 @@ +absl-py==2.1.0 +aiohttp==3.9.3 +aiosignal==1.3.1 +annotated-types==0.6.0 +anyio==4.2.0 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.1 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.2.0 +Babel==2.14.0 +backoff==2.2.1 +beautifulsoup4==4.12.3 +bleach==6.1.0 +boto3==1.34.42 +botocore==1.34.42 +cachetools==5.3.2 +certifi==2024.2.2 +cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001684923/work +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +comm==0.2.1 +contourpy==1.2.0 +cuda-python==12.3.0 +cudf-cu12==23.12.1 +cuml-cu12==23.12.0 +cupy-cuda12x==13.0.0 +cycler==0.12.1 +dask==2023.11.0 +dask-cuda==23.12.0 +dask-cudf-cu12==23.12.0 +debugpy==1.8.1 +decorator==5.1.1 +defusedxml==0.7.1 +distributed==2023.11.0 +exceptiongroup==1.2.0 +executing==2.0.1 +fastapi==0.109.2 +fastjsonschema==2.19.1 +fastrlock==0.8.2 +filelock==3.13.1 +fire==0.5.0 +fonttools==4.48.1 +fqdn==1.5.1 +frozenlist==1.4.1 +fsspec==2024.2.0 +google-auth==2.27.0 +google-auth-oauthlib==1.2.0 +grpcio==1.60.1 +h11==0.14.0 +idna==3.6 +importlib-metadata==7.0.1 +ipykernel==6.26.0 +ipython==8.17.2 +ipywidgets==8.1.1 +isoduration==20.11.0 +jedi==0.19.1 +Jinja2==3.1.3 +jmespath==1.0.1 +joblib==1.3.2 +json5==0.9.14 +jsonpointer==2.4 +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +jupyter-events==0.9.0 +jupyter-lsp==2.2.2 +jupyter_client==8.6.0 +jupyter_core==5.7.1 +jupyter_server==2.12.5 +jupyter_server_terminals==0.5.2 +jupyterlab==4.0.6 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.25.3 +jupyterlab_widgets==3.0.10 +kiwisolver==1.4.5 +lightning==2.2.0.post0 +lightning-cloud==0.5.64 +lightning-utilities==0.10.1 +lightning_sdk==0.0.18a0 +llvmlite==0.40.1 +locket==1.0.0 +Markdown==3.5.2 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.8.2 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mistune==3.0.2 +mpmath==1.3.0 +msgpack==1.0.7 +multidict==6.0.5 +nbclient==0.9.0 +nbconvert==7.16.0 +nbformat==5.9.2 +nest-asyncio==1.6.0 +networkx==3.2.1 +notebook_shim==0.2.4 +numba==0.57.1 +numpy==1.24.4 +nvtx==0.2.10 +oauthlib==3.2.2 +overrides==7.7.0 +packaging==23.2 +pandas==1.5.3 +pandocfilters==1.5.1 +parso==0.8.3 +partd==1.4.1 +pexpect==4.9.0 +pillow==10.2.0 +pkgconfig @ file:///home/conda/feedstock_root/build_artifacts/pkgconfig_1667031109701/work +platformdirs==4.2.0 +prometheus_client==0.20.0 +prompt-toolkit==3.0.43 +protobuf==4.23.4 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyarrow==14.0.2 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work +pydantic==2.6.1 +pydantic_core==2.16.2 +Pygments==2.17.2 +PyJWT==2.8.0 +pylibraft-cu12==23.12.0 +pynvml==11.4.1 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-json-logger==2.0.7 +python-multipart==0.0.9 +pytorch-lightning==2.2.0 +pytz==2024.1 +pyvips @ file:///home/conda/feedstock_root/build_artifacts/pyvips_1695697837675/work +PyYAML==6.0.1 +pyzmq==25.1.2 +raft-dask-cu12==23.12.0 +rapids-dask-dependency==23.12.1 +referencing==0.33.0 +requests==2.31.0 +requests-oauthlib==1.3.1 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.0 +rmm-cu12==23.12.0 +rpds-py==0.18.0 +rsa==4.9 +s3transfer==0.10.0 +scikit-learn==1.3.2 +scipy==1.11.4 +Send2Trash==1.8.2 +simple-term-menu==1.6.4 +six==1.16.0 +sniffio==1.3.0 +sortedcontainers==2.4.0 +soupsieve==2.5 +stack-data==0.6.3 +starlette==0.36.3 +sympy==1.12 +tblib==3.0.0 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +termcolor==2.4.0 +terminado==0.18.0 +threadpoolctl==3.3.0 +tinycss2==1.2.1 +tomli==2.0.1 +toolz==0.12.1 +torch==2.1.2+cu121 +torchaudio==2.1.2+cu121 +torchmetrics==1.2.0 +torchvision==0.16.2+cu121 +tornado==6.4 +tqdm==4.66.2 +traitlets==5.14.1 +treelite==3.9.1 +treelite-runtime==3.9.1 +triton==2.1.0 +types-python-dateutil==2.8.19.20240106 +typing_extensions==4.9.0 +tzdata==2024.1 +ucx-py-cu12==0.35.0 +uri-template==1.3.0 +urllib3==2.0.7 +uvicorn==0.27.1 +wcwidth==0.2.13 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.7.0 +Werkzeug==3.0.1 +widgetsnbextension==4.0.10 +yarl==1.9.4 +zict==3.0.0 +zipp==3.17.0