"""Command-line entry point for training ColorTransformerModel with ColorDataModule."""

from lightning.pytorch.cli import LightningCLI

# Optional callback; uncomment (and wire into the CLI/config) to use it.
# from callbacks import SaveImageCallback
from datamodule import ColorDataModule
from model import ColorTransformerModel


def cli_main(args=None):
    """Build the Lightning CLI for ColorTransformerModel / ColorDataModule.

    Args:
        args: Optional argument list (or config mapping) forwarded to
            ``LightningCLI``. When ``None`` (the default), the CLI reads
            its arguments from the command line, as before.

    Returns:
        The constructed ``LightningCLI`` instance, so programmatic callers
        can inspect the resulting model, datamodule and trainer.
    """
    # Note: don't call fit (or any other subcommand) here. LightningCLI
    # runs the subcommand chosen on the command line itself, so calling it
    # again would run it twice.
    return LightningCLI(ColorTransformerModel, ColorDataModule, args=args)


if __name__ == "__main__":
    # It is good practice to implement the CLI in a function and call it
    # from the main guard, so importing this module has no side effects.
    cli_main()