Browse Source

plotting args

new-sep-loss
Michael Pilosov, PhD 10 months ago
parent
commit
7ce24b0cd3
  1. 15
      check.py
  2. 2
      out/index.html
  3. 4
      utils.py

15
check.py

@ -10,7 +10,7 @@ from model import ColorTransformerModel
# import matplotlib.colors as mcolors # import matplotlib.colors as mcolors
def make_image(ckpt: str, fname: str, color=True): def make_image(ckpt: str, fname: str, color=True, **kwargs):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
# preds = M(rgb_tensor) # preds = M(rgb_tensor)
@ -36,10 +36,10 @@ def make_image(ckpt: str, fname: str, color=True):
ax.axis("off") ax.axis("off")
# ax.axis("square") # ax.axis("square")
plt.savefig(f"{fname}.png", dpi=300) plt.savefig(f"{fname}.png", **kwargs)
def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True): def create_circle(ckpt: str, fname: str, skip: bool = True, **kwargs):
if isinstance(ckpt, str): if isinstance(ckpt, str):
M = ColorTransformerModel.load_from_checkpoint(ckpt) M = ColorTransformerModel.load_from_checkpoint(ckpt)
else: else:
@ -48,7 +48,7 @@ def create_circle(ckpt: str, fname: str, dpi: int = 150, skip: bool = True):
rgb_tensor, _ = extract_colors() rgb_tensor, _ = extract_colors()
rgb_tensor = preprocess_data(rgb_tensor) rgb_tensor = preprocess_data(rgb_tensor)
preds = M(rgb_tensor.to(M.device)) preds = M(rgb_tensor.to(M.device))
plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, dpi=dpi) plot_preds(preds, rgb_tensor.detach().cpu().numpy(), fname=fname, **kwargs)
def plot_preds( def plot_preds(
@ -116,10 +116,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# make the following accept a list of arguments # make the following accept a list of arguments
parser.add_argument("-v", "--version", type=int, nargs="+", default=[0, 1]) parser.add_argument("-v", "--version", type=int, nargs="+", default=[0])
parser.add_argument( parser.add_argument(
"--dpi", type=int, default=150, help="Resolution for saved image." "--dpi", type=int, default=150, help="Resolution for saved image."
) )
parser.add_argument("--figsize", type=int, default=3, help="Figure size")
args = parser.parse_args() args = parser.parse_args()
versions = args.version versions = args.version
for v in versions: for v in versions:
@ -130,7 +131,7 @@ if __name__ == "__main__":
if len(ckpt) > 0: if len(ckpt) > 0:
ckpt = ckpt[-1] ckpt = ckpt[-1]
print(f"Generating image for checkpoint: {ckpt}") print(f"Generating image for checkpoint: {ckpt}")
create_circle(ckpt, fname=name, dpi=args.dpi) create_circle(ckpt, fname=name, dpi=args.dpi, figsize=[args.figsize] * 2)
else: else:
print(f"No checkpoint found for version {v}") print(f"No checkpoint found for version {v}")
# make_image(ckpt, fname=name + "b", color=False) # make_image(ckpt, fname=name + "b", color=False, dpi=args.dpi,)

2
out/index.html

@ -77,7 +77,7 @@
if (i == -21) { if (i == -21) {
imageName = 'hsv.png'; imageName = 'hsv.png';
} else { } else {
imageName = 'v' + i + '.png'; imageName = 'v' + i + '.jpg';
} }
let img = document.createElement('img'); let img = document.createElement('img');
img.src = imageName; img.src = imageName;

4
utils.py

@ -34,7 +34,9 @@ def extract_colors():
return rgb_tensor, xkcd_color_names return rgb_tensor, xkcd_color_names
PURE_RGB = preprocess_data(torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)) PURE_RGB = preprocess_data(
torch.cat([torch.eye(3), torch.eye(3) + torch.eye(3)[:, [1, 2, 0]]], dim=0)
)
PURE_HSV = torch.tensor( PURE_HSV = torch.tensor(
[[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32 [[0], [1 / 3], [2 / 3], [5 / 6], [1 / 6], [0.5]], dtype=torch.float32
) )

Loading…
Cancel
Save