diff --git a/check.py b/check.py index 249e35e..467234a 100644 --- a/check.py +++ b/check.py @@ -1,4 +1,6 @@ # import matplotlib.patches as patches +from typing import Union + import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np @@ -39,16 +41,19 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs): plt.savefig(f"{fname}.png", **kwargs) -def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs): +def create_circle( + ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs +): if isinstance(ckpt, str): M = ColorTransformerModel.load_from_checkpoint(ckpt) else: M = ckpt - rgb_tensor, _ = extract_colors() - rgb_tensor = preprocess_data(rgb_tensor) - preds = M(rgb_tensor.to(M.device)) - plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs) + xkcd_colors, _ = extract_colors() + xkcd_colors = preprocess_data(xkcd_colors).to(M.device) + preds = M(xkcd_colors) + rgb_array = xkcd_colors.detach().cpu().numpy() + plot_preds(preds, rgb_array, fname=fname, **kwargs) def plot_preds(