Browse Source

type hint, rename vars

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
3adcc9779a
  1. 15
      check.py

15
check.py

@ -1,4 +1,6 @@
# import matplotlib.patches as patches
from typing import Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
@ -39,16 +41,19 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs):
plt.savefig(f"{fname}.png", **kwargs)
def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs):
def create_circle(
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs
):
if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt)
else:
M = ckpt
rgb_tensor, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor)
preds = M(rgb_tensor.to(M.device))
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs)
xkcd_colors, _ = extract_colors()
xkcd_colors = preprocess_data(xkcd_colors).to(M.device)
preds = M(xkcd_colors)
rgb_array = xkcd_colors.detach().cpu().numpy()
plot_preds(preds, rgb_array, fname=fname, **kwargs)
def plot_preds(

Loading…
Cancel
Save