@ -25,7 +25,7 @@ class SaveImageCallback(Callback):
# else:
# version = 0
fname = Path(pl_module.trainer.logger.log_dir) / Path(f"e{epoch:04d}")
create_circle(pl_module, fname=fname)
create_circle(pl_module, fname=fname, dpi=300, figsize=(6, 6))
# Make sure to set it back to train mode
pl_module.train()
@ -61,7 +61,7 @@ def create_circle(
def plot_preds(
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 150, figsize=(3, 3)
preds, rgb_values, fname: str, roll: bool = False, dpi: int = 300, figsize=(6, 6)
):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()