|
|
@ -1,4 +1,6 @@ |
|
|
|
# import matplotlib.patches as patches |
|
|
|
import pickle |
|
|
|
from multiprocessing import Process |
|
|
|
from pathlib import Path |
|
|
|
from typing import Tuple, Union |
|
|
|
|
|
|
@ -42,11 +44,34 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs): |
|
|
|
plt.savefig(f"{fname}.png", **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
# def create_circle( |
|
|
|
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs |
|
|
|
# ): |
|
|
|
# if isinstance(ckpt, str): |
|
|
|
|
|
|
|
# M = ColorTransformerModel.load_from_checkpoint( |
|
|
|
# ckpt, map_location=lambda storage, loc: storage |
|
|
|
# ) |
|
|
|
# else: |
|
|
|
# M = ckpt |
|
|
|
|
|
|
|
# xkcd_colors, _ = extract_colors() |
|
|
|
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
|
# preds = M(xkcd_colors).detach().cpu().numpy() |
|
|
|
# rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
|
# plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def plot_preds_serialized(serialized_data, fname, **kwargs): |
|
|
|
# Deserialize the data |
|
|
|
preds, rgb_array = pickle.loads(serialized_data) |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def create_circle( |
|
|
|
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs |
|
|
|
): |
|
|
|
if isinstance(ckpt, str): |
|
|
|
|
|
|
|
M = ColorTransformerModel.load_from_checkpoint( |
|
|
|
ckpt, map_location=lambda storage, loc: storage |
|
|
|
) |
|
|
@ -57,7 +82,16 @@ def create_circle( |
|
|
|
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
|
preds = M(xkcd_colors).detach().cpu().numpy() |
|
|
|
rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
# Serialize the data |
|
|
|
serialized_data = pickle.dumps((preds, rgb_array)) |
|
|
|
|
|
|
|
# Run plot_preds_serialized function in a separate process |
|
|
|
p = Process( |
|
|
|
target=plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs |
|
|
|
) |
|
|
|
p.start() |
|
|
|
return p |
|
|
|
|
|
|
|
|
|
|
|
def plot_preds( |
|
|
@ -65,7 +99,7 @@ def plot_preds( |
|
|
|
rgb_values, |
|
|
|
fname: str, |
|
|
|
roll: bool = False, |
|
|
|
inner_radius: float = 1 / 3, |
|
|
|
radius: float = 1 / 2, |
|
|
|
dpi: int = 300, |
|
|
|
figsize: Tuple[float] = (6, 6), |
|
|
|
fsize: int = 0, |
|
|
@ -113,12 +147,11 @@ def plot_preds( |
|
|
|
ax.set_yticks([]) |
|
|
|
ax.set_aspect("equal") |
|
|
|
ax.axis("off") |
|
|
|
radius = 1 |
|
|
|
ax.set_ylim(0, radius) |
|
|
|
ax.set_ylim(0, 1) # implicit outer radius of 1 |
|
|
|
|
|
|
|
# Overlay white circle |
|
|
|
circle = patches.Circle( |
|
|
|
(0, 0), inner_radius, transform=ax.transData._b, color="white", zorder=2 |
|
|
|
(0, 0), radius, transform=ax.transData._b, color="white", zorder=2 |
|
|
|
) |
|
|
|
ax.add_patch(circle) |
|
|
|
|
|
|
|