Browse Source

finalize plotting (6in, no roll)

main
Michael Pilosov, PhD 9 months ago
parent
commit
c909ad4372
  1. 21
      check.py
  2. 2
      makefile
  3. 197
      requirements-new.txt

21
check.py

@ -1,5 +1,6 @@
# import matplotlib.patches as patches # import matplotlib.patches as patches
from typing import Union from typing import Union
from pathlib import Path
import matplotlib.patches as patches import matplotlib.patches as patches
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -45,7 +46,8 @@ def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
): ):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) import yaml
M = ColorTransformerModel.load_from_checkpoint(ckpt, map_location=lambda storage, loc: storage)
else: else:
M = ckpt M = ckpt
@ -105,7 +107,6 @@ def plot_preds(
ax.axis("off") ax.axis("off")
radius = 1 radius = 1
ax.set_ylim(-radius, radius) ax.set_ylim(-radius, radius)
ax.set_xlim(-radius, radius)
# Overlay white circle # Overlay white circle
inner_radius = 1 / 3 inner_radius = 1 / 3
@ -131,20 +132,26 @@ if __name__ == "__main__":
# make the following accept a list of arguments # make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0]) parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
parser.add_argument( parser.add_argument(
"--dpi", type=int, default=150, help="Resolution for saved image." "--dpi", type=int, default=300, help="Resolution for saved image."
) )
parser.add_argument("--figsize", type=int, default=3, help="Figure size") parser.add_argument("--figsize", type=int, default=6, help="Figure size")
args = parser.parse_args() args = parser.parse_args()
versions = args.version versions = args.version
for v in versions: for v in versions:
name = f"out/v{v}" # name = f"out/v{v}"
studio = "colors-refactor-supervised"
# studio = "colors-refactor-unsupervised"
# studio = "colors-refactor-unsupervised-anchors"
Path(studio).mkdir(exist_ok=True, parents=True)
name = f"{studio}/v{v}"
# ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt" # ckpt = f"/teamspace/jobs/{name}/work/colors/lightning_logs/version_2/checkpoints/epoch=999-step=8000.ckpt"
ckpt_path = f"/teamspace/studios/colors-refactor-secondary/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" # ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt = glob.glob(ckpt_path) ckpt = glob.glob(ckpt_path)
if len(ckpt) > 0: if len(ckpt) > 0:
ckpt = ckpt[-1] ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}") print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2) create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2, roll=False)
else: else:
print(f"No checkpoint found for version {v}") print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,) # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)

2
makefile

@ -59,7 +59,7 @@ sort_umap:
python scripts/sortcolor.py -s umap --dpi 300 --seed 21 python scripts/sortcolor.py -s umap --dpi 300 --seed 21
parallel_umap: parallel_umap:
parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 1000) parallel -j 12 python scripts/sortcolor.py -s umap --dpi 300 --seed ::: $$(seq 1 100)
sort_lex: sort_lex:
python scripts/sortcolor.py -s lex --dpi 300 python scripts/sortcolor.py -s lex --dpi 300

197
requirements-new.txt

@ -0,0 +1,197 @@
absl-py==2.1.0
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
backoff==2.2.1
beautifulsoup4==4.12.3
bleach==6.1.0
boto3==1.34.42
botocore==1.34.42
cachetools==5.3.2
certifi==2024.2.2
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001684923/work
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.1
contourpy==1.2.0
cuda-python==12.3.0
cudf-cu12==23.12.1
cuml-cu12==23.12.0
cupy-cuda12x==13.0.0
cycler==0.12.1
dask==2023.11.0
dask-cuda==23.12.0
dask-cudf-cu12==23.12.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
distributed==2023.11.0
exceptiongroup==1.2.0
executing==2.0.1
fastapi==0.109.2
fastjsonschema==2.19.1
fastrlock==0.8.2
filelock==3.13.1
fire==0.5.0
fonttools==4.48.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.2.0
google-auth==2.27.0
google-auth-oauthlib==1.2.0
grpcio==1.60.1
h11==0.14.0
idna==3.6
importlib-metadata==7.0.1
ipykernel==6.26.0
ipython==8.17.2
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.2
jupyter_client==8.6.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.2
jupyterlab==4.0.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.3
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
lightning==2.2.0.post0
lightning-cloud==0.5.64
lightning-utilities==0.10.1
lightning_sdk==0.0.18a0
llvmlite==0.40.1
locket==1.0.0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.2
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.5
nbclient==0.9.0
nbconvert==7.16.0
nbformat==5.9.2
nest-asyncio==1.6.0
networkx==3.2.1
notebook_shim==0.2.4
numba==0.57.1
numpy==1.24.4
nvtx==0.2.10
oauthlib==3.2.2
overrides==7.7.0
packaging==23.2
pandas==1.5.3
pandocfilters==1.5.1
parso==0.8.3
partd==1.4.1
pexpect==4.9.0
pillow==10.2.0
pkgconfig @ file:///home/conda/feedstock_root/build_artifacts/pkgconfig_1667031109701/work
platformdirs==4.2.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
protobuf==4.23.4
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.17.2
PyJWT==2.8.0
pylibraft-cu12==23.12.0
pynvml==11.4.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
python-multipart==0.0.9
pytorch-lightning==2.2.0
pytz==2024.1
pyvips @ file:///home/conda/feedstock_root/build_artifacts/pyvips_1695697837675/work
PyYAML==6.0.1
pyzmq==25.1.2
raft-dask-cu12==23.12.0
rapids-dask-dependency==23.12.1
referencing==0.33.0
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rmm-cu12==23.12.0
rpds-py==0.18.0
rsa==4.9
s3transfer==0.10.0
scikit-learn==1.3.2
scipy==1.11.4
Send2Trash==1.8.2
simple-term-menu==1.6.4
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.36.3
sympy==1.12
tblib==3.0.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado==0.18.0
threadpoolctl==3.3.0
tinycss2==1.2.1
tomli==2.0.1
toolz==0.12.1
torch==2.1.2+cu121
torchaudio==2.1.2+cu121
torchmetrics==1.2.0
torchvision==0.16.2+cu121
tornado==6.4
tqdm==4.66.2
traitlets==5.14.1
treelite==3.9.1
treelite-runtime==3.9.1
triton==2.1.0
types-python-dateutil==2.8.19.20240106
typing_extensions==4.9.0
tzdata==2024.1
ucx-py-cu12==0.35.0
uri-template==1.3.0
urllib3==2.0.7
uvicorn==0.27.1
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.1
widgetsnbextension==4.0.10
yarl==1.9.4
zict==3.0.0
zipp==3.17.0
Loading…
Cancel
Save