|
|
@ -1,4 +1,6 @@ |
|
|
|
# import matplotlib.patches as patches |
|
|
|
from typing import Union |
|
|
|
|
|
|
|
import matplotlib.patches as patches |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
|
@ -39,16 +41,19 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs): |
|
|
|
plt.savefig(f"{fname}.png", **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs): |
|
|
|
def create_circle( |
|
|
|
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs |
|
|
|
): |
|
|
|
if isinstance(ckpt, str): |
|
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
|
else: |
|
|
|
M = ckpt |
|
|
|
|
|
|
|
rgb_tensor, _ = extract_colors() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
preds = M(rgb_tensor.to(M.device)) |
|
|
|
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs) |
|
|
|
xkcd_colors, _ = extract_colors() |
|
|
|
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
|
preds = M(xkcd_colors) |
|
|
|
rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def plot_preds( |
|
|
|