|
|
@ -15,21 +15,21 @@ from model import ColorTransformerModel |
|
|
|
# import matplotlib.colors as mcolors |
|
|
|
|
|
|
|
|
|
|
|
def make_image(ckpt: str, fname: str, color=True, **kwargs): |
|
|
|
def create_rectangle(ckpt: str, fname: str, color: bool = True, **kwargs): |
|
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
|
|
|
|
|
# preds = M(rgb_tensor) |
|
|
|
if not color: |
|
|
|
if color is False: # black and white ordering... |
|
|
|
N = 949 |
|
|
|
linear_space = torch.linspace(0, 1, N) |
|
|
|
rgb_tensor = linear_space.unsqueeze(1).repeat(1, 3) |
|
|
|
else: |
|
|
|
rgb_tensor, names = extract_colors() |
|
|
|
|
|
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
|
preds = M(rgb_tensor) |
|
|
|
sorted_inds = np.argsort(preds.detach().numpy().ravel()) |
|
|
|
rgb_tensor = preprocess_data(rgb_tensor).to(M.device) |
|
|
|
preds = M(rgb_tensor).detach().cpu().numpy() |
|
|
|
rgb_values = rgb_tensor.detach().cpu().numpy() |
|
|
|
sorted_inds = np.argsort(preds.ravel()) |
|
|
|
|
|
|
|
fig, ax = plt.subplots() |
|
|
|
for i in range(len(sorted_inds)): |
|
|
@ -44,33 +44,8 @@ def make_image(ckpt: str, fname: str, color=True, **kwargs): |
|
|
|
plt.savefig(f"{fname}.png", **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
# def create_circle( |
|
|
|
# ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs |
|
|
|
# ): |
|
|
|
# if isinstance(ckpt, str): |
|
|
|
|
|
|
|
# M = ColorTransformerModel.load_from_checkpoint( |
|
|
|
# ckpt, map_location=lambda storage, loc: storage |
|
|
|
# ) |
|
|
|
# else: |
|
|
|
# M = ckpt |
|
|
|
|
|
|
|
# xkcd_colors, _ = extract_colors() |
|
|
|
# xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
|
# preds = M(xkcd_colors).detach().cpu().numpy() |
|
|
|
# rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
|
# plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
def do_inference(ckpt: Union[str, ColorTransformerModel]): |
|
|
|
|
|
|
|
|
|
|
|
def plot_preds_serialized(serialized_data, fname, **kwargs): |
|
|
|
# Deserialize the data |
|
|
|
preds, rgb_array = pickle.loads(serialized_data) |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def create_circle( |
|
|
|
ckpt: Union[str, ColorTransformerModel], fname: str, skip: bool = True, **kwargs |
|
|
|
): |
|
|
|
if isinstance(ckpt, str): |
|
|
|
M = ColorTransformerModel.load_from_checkpoint( |
|
|
|
ckpt, map_location=lambda storage, loc: storage |
|
|
@ -82,13 +57,31 @@ def create_circle( |
|
|
|
xkcd_colors = preprocess_data(xkcd_colors).to(M.device) |
|
|
|
preds = M(xkcd_colors).detach().cpu().numpy() |
|
|
|
rgb_array = xkcd_colors.detach().cpu().numpy() |
|
|
|
return preds, rgb_array |
|
|
|
|
|
|
|
|
|
|
|
def create_circle(ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs): |
|
|
|
preds, rgb_array = do_inference(ckpt) |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def _plot_preds_serialized(serialized_data, fname, **kwargs): |
|
|
|
# Deserialize the data |
|
|
|
preds, rgb_array = pickle.loads(serialized_data) |
|
|
|
plot_preds(preds, rgb_array, fname=fname, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def create_circle_nonblocking( |
|
|
|
ckpt: Union[str, ColorTransformerModel], fname: str, **kwargs |
|
|
|
): |
|
|
|
preds, rgb_array = do_inference(ckpt) |
|
|
|
|
|
|
|
# Serialize the data |
|
|
|
serialized_data = pickle.dumps((preds, rgb_array)) |
|
|
|
|
|
|
|
# Run plot_preds_serialized function in a separate process |
|
|
|
# Run _plot_preds_serialized function in a separate process |
|
|
|
p = Process( |
|
|
|
target=plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs |
|
|
|
target=_plot_preds_serialized, args=(serialized_data, fname), kwargs=kwargs |
|
|
|
) |
|
|
|
p.start() |
|
|
|
return p |
|
|
@ -96,7 +89,7 @@ def create_circle( |
|
|
|
|
|
|
|
def plot_preds( |
|
|
|
preds: np.ndarray, |
|
|
|
rgb_values, |
|
|
|
rgb_values: np.ndarray, |
|
|
|
fname: str, |
|
|
|
roll: bool = False, |
|
|
|
radius: float = 1 / 2, |
|
|
@ -173,12 +166,10 @@ def plot_preds( |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
# name = "color_128_0.3_1.00e-06" |
|
|
|
import argparse |
|
|
|
import glob |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
# make the following accept a list of arguments |
|
|
|
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0]) |
|
|
|
parser.add_argument( |
|
|
|
"--dpi", type=int, default=300, help="Resolution for saved image." |
|
|
@ -186,7 +177,7 @@ if __name__ == "__main__": |
|
|
|
parser.add_argument( |
|
|
|
"--studio", |
|
|
|
type=str, |
|
|
|
default="this_studio", |
|
|
|
default=["this_studio"], |
|
|
|
nargs="+", |
|
|
|
help="Checkpoint studio name.", |
|
|
|
) |
|
|
@ -201,8 +192,8 @@ if __name__ == "__main__": |
|
|
|
# ckpt_path = f"/teamspace/studios/this_studio/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" |
|
|
|
ckpt_path = f"/teamspace/studios/{studio}/colors/lightning_logs/version_{v}/checkpoints/*.ckpt" |
|
|
|
ckpt = glob.glob(ckpt_path) |
|
|
|
if len(ckpt) > 0: |
|
|
|
ckpt = ckpt[-1] |
|
|
|
if len(ckpt) > 0: # get latest checkpoint |
|
|
|
ckpt = ckpt[-1] # TODO: allow specification via CLI |
|
|
|
print(f"Generating image for checkpoint: {ckpt}") |
|
|
|
create_circle( |
|
|
|
ckpt, |
|
|
|