|
@ -8,8 +8,10 @@ from model import ColorTransformerModel |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
# import matplotlib.colors as mcolors |
|
|
# import matplotlib.colors as mcolors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_image(ckpt: str, fname: str, color=True): |
|
|
def make_image(ckpt: str, fname: str, color=True): |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
|
|
|
|
|
@ -41,6 +43,7 @@ def make_image(ckpt: str, fname: str, color=True): |
|
|
|
|
|
|
|
|
def create_circle(ckpt: str, fname: str): |
|
|
def create_circle(ckpt: str, fname: str): |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
M = ColorTransformerModel.load_from_checkpoint(ckpt) |
|
|
|
|
|
M.eval() |
|
|
rgb_tensor, names = extract_colors() |
|
|
rgb_tensor, names = extract_colors() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_values = rgb_tensor.detach().numpy() |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
|
rgb_tensor = preprocess_data(rgb_tensor) |
|
@ -69,6 +72,7 @@ def create_circle(ckpt: str, fname: str): |
|
|
ax.axis("off") |
|
|
ax.axis("off") |
|
|
fig.tight_layout() |
|
|
fig.tight_layout() |
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
plt.savefig(f"{fname}.png", dpi=300) |
|
|
|
|
|
plt.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|