"""Command-line entry point wiring the project's model and data module into LightningCLI.

``ColorTransformerModel`` and ``ColorDataModule`` are handed to Lightning's
``LightningCLI``, which builds them (and the ``Trainer``) from command-line
arguments / config files.
"""

from lightning.pytorch.cli import LightningCLI

from datamodule import ColorDataModule
from model import ColorTransformerModel


def cli_main() -> LightningCLI:
    """Build and run the Lightning CLI for the color model and data module.

    ``LightningCLI`` parses the command line and executes the requested
    subcommand (e.g. ``fit``) while it is being constructed. So do NOT call
    ``trainer.fit`` here as well.

    Returns:
        The constructed ``LightningCLI``. Previously this function returned
        ``None``; returning the instance lets callers (tests, notebooks) inspect
        ``cli.trainer``, ``cli.model`` and ``cli.datamodule`` after the run.
    """
    # NOTE(review): a SaveImageCallback (from ``callbacks``) used to be wired
    # into a hand-built Trainer here. With the CLI it would presumably be passed
    # via ``--trainer.callbacks`` or the YAML config instead; confirm before relying on it.
    return LightningCLI(ColorTransformerModel, ColorDataModule)


if __name__ == "__main__":
    # Construct the CLI only under the main guard, so importing this module
    # (e.g. from tests) does not parse argv or start a run.
    cli_main()