Browse Source

type hint, rename vars

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
3adcc9779a
  1. 15
      check.py

15
check.py

@ -1,4 +1,6 @@
# import matplotlib.patches as patches # import matplotlib.patches as patches
from typing import Union
import matplotlib.patches as patches import matplotlib.patches as patches
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -39,16 +41,19 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs):
plt.savefig(f"{fname}.png", **kwargs) plt.savefig(f"{fname}.png", **kwargs)
def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs): def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
else: else:
M = ckpt M = ckpt
rgb_tensor, _ = extract_colors() xkcd_colors, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor) xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(rgb_tensor.to(M.device)) preds = M(xkcd_colors)
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs) rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds( def plot_preds(

Loading…
Cancel
Save