Browse Source

plotting updates, poster

plotting-unify
Michael Pilosov, PhD 9 months ago
parent
commit
09a2333297
  1. 2
      callbacks.py
  2. 71
      check.py
  3. 14
      datamodule.py
  4. 3
      makefile
  5. 6
      scripts/color_poster.py
  6. 44
      scripts/install_ptmono.sh
  7. 7
      scripts/sortcolor.py

2
callbacks.py

@ -4,7 +4,7 @@ from pathlib import Path
from lightning import Callback from lightning import Callback
from check import create_circle from check import create_circle_nonblocking as create_circle
class SaveImageCallback(Callback): class SaveImageCallback(Callback):

71
check.py

@ -15,21 +15,21 @@ from model import ColorTransformerModel
# import matplotlib.colors as mcolors # import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True, **kwargs): def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
# preds = M(rgb_tensor) # preds = M(rgb_tensor)
if not color: if color is False: # black and white ordering...
N = 949 N = 949
linear_space = torch.linspace(0, 1, N) linear_space = torch.linspace(0, 1, N)
rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3) rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3)
else: else:
rgb_tensor, names = extract_colors() rgb_tensor, names = extract_colors()
rgb_values = rgb_tensor.detach().numpy() rgb_tensor = preprocess_data(rgb_tensor).to(M.device)
rgb_tensor = preprocess_data(rgb_tensor) preds = M(rgb_tensor).detach().cpu().numpy()
preds = M(rgb_tensor) rgb_values = rgb_tensor.detach().cpu().numpy()
sorted_inds = np.argsort(preds.detach().numpy().ravel()) sorted_inds = np.argsort(preds.ravel())
fig, ax = plt.subplots() fig, ax = plt.subplots()
for i in range(len(sorted_inds)): for i in range(len(sorted_inds)):
@ -44,33 +44,8 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs):
plt.savefig(f"{fname}.png", **kwargs) plt.savefig(f"{fname}.png", **kwargs)
# def create_circle( def do_inference(ckpt: Union[str, ColorTransformerModel]):
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
# ):
# if isinstance(ckpt, str):
# M = ColorTransformerModel.load_from_checkpoint(
# ckpt, map_location=lambda storage, loc: storage
# )
# else:
# M = ckpt
# xkcd_colors, _ = extract_colors()
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
# preds = M(xkcd_colors).detach().cpu().numpy()
# rgb_array = xkcd_colors.detach().cpu().numpy()
# plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds_serialized(serialized_data, fname, **kwargs):
# Deserialize the data
preds, rgb_array = pickle.loads(serialized_data)
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint( M = ColorTransformerModel.load_from_checkpoint(
ckpt, map_location=lambda storage, loc: storage ckpt, map_location=lambda storage, loc: storage
@ -82,13 +57,31 @@ def create_circle(
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors).detach().cpu().numpy() preds = M(xkcd_colors).detach().cpu().numpy()
rgb_array = xkcd_colors.detach().cpu().numpy() rgb_array = xkcd_colors.detach().cpu().numpy()
return preds, rgb_array
def create_circle(ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs):
preds, rgb_array = do_inference(ckpt)
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def _plot_preds_serialized(serialized_data, fname, **kwargs):
# Deserialize the data
preds, rgb_array = pickle.loads(serialized_data)
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def create_circle_nonblocking(
ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs
):
preds, rgb_array = do_inference(ckpt)
# Serialize the data # Serialize the data
serialized_data = pickle.dumps((preds, rgb_array)) serialized_data = pickle.dumps((preds, rgb_array))
# Run plot_preds_serialized function in a separate process # Run _plot_preds_serialized function in a separate process
p = Process( p = Process(
target=plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs target=_plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs
) )
p.start() p.start()
return p return p
@ -96,7 +89,7 @@ def create_circle(
def plot_preds( def plot_preds(
preds: np.ndarray, preds: np.ndarray,
rgb_values, rgb_values: np.ndarray,
fname: str, fname: str,
roll: bool = False, roll: bool = False,
radius: float = 1 / 2, radius: float = 1 / 2,
@ -173,12 +166,10 @@ def plot_preds(
if __name__ == "__main__": if __name__ == "__main__":
# name = "color_128_0.3_1.00e-06"
import argparse import argparse
import glob import glob
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0]) parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
parser.add_argument( parser.add_argument(
"--dpi", type=int, default=300, help="Resolution for saved image." "--dpi", type=int, default=300, help="Resolution for saved image."
@ -186,7 +177,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--studio", "--studio",
type=str, type=str,
default="this_studio", default=["this_studio"],
nargs="+", nargs="+",
help="Checkpoint studio name.", help="Checkpoint studio name.",
) )
@ -201,8 +192,8 @@ if __name__ == "__main__":
# ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" # ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt"
ckpt = glob.glob(ckpt_path) ckpt = glob.glob(ckpt_path)
if len(ckpt) > 0: if len(ckpt) > 0: # get latest checkpoint
ckpt = ckpt[-1] ckpt = ckpt[-1] # TODO: allow specification via CLI
print(f"Generating image for checkpoint: {ckpt}") print(f"Generating image for checkpoint: {ckpt}")
create_circle( create_circle(
ckpt, ckpt,

14
datamodule.py

@ -35,14 +35,18 @@ class ColorDataModule(L.LightningDataModule):
return [(c, cls.get_hue(c)) for c in train_rgb] return [(c, cls.get_hue(c)) for c in train_rgb]
@classmethod @classmethod
def get_xkcd_colors(cls): def get_xkcd_colors(cls, label="hues"):
rgb_tensor, xkcd_color_names = extract_colors() rgb_tensor, xkcd_color_names = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor, skip=True) rgb_tensor = preprocess_data(rgb_tensor, skip=True)
# return [ if label == "names":
# (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", "")) return [
# for i in range(len(rgb_tensor)) (rgb_tensor[i], xkcd_color_names[i].replace("xkcd:", ""))
# ] for i in range(len(rgb_tensor))
]
if label == "hues":
return [(c, cls.get_hue(c)) for c in rgb_tensor] return [(c, cls.get_hue(c)) for c in rgb_tensor]
else:
raise ValueError("Please specify `label` as one of ['hues', 'names'].")
def setup(self, stage: str): def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders # Assign train/val datasets for use in dataloaders

3
makefile

@ -72,6 +72,9 @@ sort_lex:
sort_hsv: sort_hsv:
python scripts/sortcolor.py -s hsv --dpi 300 python scripts/sortcolor.py -s hsv --dpi 300
poster: sort_lex sort_hsv
python scripts/color_poster.py -k hsv lex --rows 119
clean: clean:
rm -rf lightning_logs rm -rf lightning_logs
rm -rf .lr_find_*.ckpt rm -rf .lr_find_*.ckpt

6
color_poster.py → scripts/color_poster.py

@ -168,14 +168,14 @@ if __name__ == "__main__":
"--rows", type=int, default=73, help="Number of entries per column" "--rows", type=int, default=73, help="Number of entries per column"
) )
parser.add_argument( parser.add_argument(
"--dir", type=str, default="~/color/out", help="Directory to save images" "--dir", type=str, default="~/out_sortcolors", help="Directory to save images"
) )
parser.add_argument( parser.add_argument(
"-k", "-k",
"--kind", "--kind",
type=str, type=str,
nargs="+", nargs="+",
default=["hsv", "lex", "lab", "umap"], default=["hsv", "lex"],
help="Kinds of sorting", help="Kinds of sorting",
) )
parser.add_argument( parser.add_argument(
@ -209,7 +209,7 @@ if __name__ == "__main__":
# KIND = "hsv" # choose from umap, hsv # KIND = "hsv" # choose from umap, hsv
for KIND in KINDS: for KIND in KINDS:
colors = list(mcolors.XKCD_COLORS.keys()) colors = list(mcolors.XKCD_COLORS.keys())
sorted_indices = np.load(f"scripts/{KIND}_sorted_indices.npy") sorted_indices = np.load(f"/teamspace/studios/this_studio/out_sortcolors/{KIND}/sorted_indices.npy")
sorted_colors = [colors[idx] for idx in sorted_indices] sorted_colors = [colors[idx] for idx in sorted_indices]
colors = sorted_colors colors = sorted_colors

44
scripts/install_ptmono.sh

@ -0,0 +1,44 @@
#!/bin/bash
# Source: https://blog.programster.org/ubuntu-install-pt-mono-font
# Installs PT Mono font onto Ubuntu 12.04 or 14.04 Systems
# Check to make sure the user has the unzip package installed
export NO_UNZIP=$(apt-cache policy unzip | grep "Installed: (none)" | wc -l)
# Result will be 1 if it is NOT installed
if [ "$NO_UNZIP" = "0" ]; then
export TEMP_DIR='temp-technostu-script'
cd ~
mkdir $TEMP_DIR
cd $TEMP_DIR
# Download PT mono from google fonts
export FONT_URL="http://www.google.com/fonts/download"
export FONT_URL="$FONT_URL?kit=7qsh9BNBJbZ6khIbS3ZpfKCWcynf_cDxXwCLxiixG1c"
wget --content-disposition "$FONT_URL"
# Create a PT_Mono directory which we will copy across into the fonts directory.
mkdir PT_Mono
mv PT_Mono.zip PT_Mono/.
cd PT_Mono
unzip PT_Mono.zip
rm PT_Mono.zip
cd ..
sudo mv PT_Mono /usr/share/fonts/truetype/.
# Re-cache the fonts
echo 'Re-caching fonts...'
sudo fc-cache -fv
# cleanup
cd ~
sudo rm -rf $TEMP_DIR
echo 'done!'
else
# User doesnt have unzip installed, tell them how to install it
echo 'You need to install unzip for this to work: try '
echo '"sudo apt-get install unzip"'
fi

7
scripts/sortcolor.py

@ -82,10 +82,11 @@ def peano_curve(n):
if KIND in ("lex", "alpha", "abc"): if KIND in ("lex", "alpha", "abc"):
KIND = "lex"
preds = np.array(colors) preds = np.array(colors)
elif KIND == "umap": elif KIND == "umap":
PDIR = f"scripts/{KIND}-prod" PDIR = f"/teamspace/studios/this_studio/out_sortcolors/{KIND}"
Path(PDIR).mkdir(parents=True, exist_ok=True) Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{SEED:06d}.npy" file_path = f"{PDIR}/{SEED:06d}.npy"
if Path(file_path).exists(): if Path(file_path).exists():
@ -147,9 +148,9 @@ else:
raise ValueError(f"Unknown kind: {KIND}") raise ValueError(f"Unknown kind: {KIND}")
PDIR = f"scripts" PDIR = f"/teamspace/studios/this_studio/out_sortcolors/{KIND}"
Path(PDIR).mkdir(parents=True, exist_ok=True) Path(PDIR).mkdir(parents=True, exist_ok=True)
file_path = f"{PDIR}/{KIND}_sorted_indices.npy" file_path = f"{PDIR}/sorted_indices.npy"
# Sort colors by the 1D representation # Sort colors by the 1D representation
sorted_indices = np.argsort(preds) sorted_indices = np.argsort(preds)
sorted_colors = [colors[i] for i in sorted_indices] sorted_colors = [colors[i] for i in sorted_indices]

Loading…
Cancel
Save