You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
651 B

from lightning.pytorch.cli import LightningCLI
# from callbacks import SaveImageCallback
from datamodule import ColorDataModule
from model import ColorTransformerModel
def cli_main(args=None):
    """Build and run the LightningCLI for the color transformer.

    The CLI is constructed with ``ColorTransformerModel`` as the model class
    and ``ColorDataModule`` as the datamodule class. With LightningCLI's
    default ``run=True``, construction parses a subcommand (e.g. ``fit``)
    and executes it. This function must therefore NOT call ``trainer.fit``
    itself, or training would run a second time.

    Args:
        args: Optional arguments forwarded to ``LightningCLI(args=...)``.
            ``None`` (the default) lets LightningCLI parse ``sys.argv``,
            which is the original command-line behaviour.

    Returns:
        The constructed ``LightningCLI`` instance. Callers such as tests or
        notebooks can inspect ``cli.trainer`` / ``cli.model`` afterwards.
        Returning it also replaces the previously unused local variable.
    """
    return LightningCLI(ColorTransformerModel, ColorDataModule, args=args)
if __name__ == "__main__":
    # Keeping the CLI in a function and invoking it only from the main guard
    # lets this module be imported without starting a run.
    cli_main()

# NOTE(review): legacy manual-Trainer setup, apparently superseded by the
# LightningCLI entry point above; kept for reference only. Re-enabling it
# would need `import lightning.pytorch as pl`, the commented-out
# SaveImageCallback import, and a `model` instance. None of these are
# imported or defined in the visible code.
# save_img_callback = SaveImageCallback(
#     save_interval=0,
#     final_dir="out",
# )
# trainer = pl.Trainer(
#     callbacks=[save_img_callback],
# )
# data = ColorDataModule()
# trainer.fit(model, data)