"""Command-line entry point wiring ``ColorTransformerModel`` and ``ColorDataModule`` into Lightning's ``LightningCLI``."""
from lightning.pytorch.cli import LightningCLI
|
||
|
|
||
|
# from callbacks import SaveImageCallback
|
||
|
from datamodule import ColorDataModule
|
||
|
from model import ColorTransformerModel
|
||
|
|
||
|
|
||
|
def cli_main(args=None):
    """Build the LightningCLI for the color transformer and run the requested subcommand.

    ``LightningCLI`` parses the command line (or ``args``), instantiates
    ``ColorTransformerModel`` and ``ColorDataModule`` from the parsed
    configuration, and, with its default ``run=True``, executes the chosen
    subcommand (``fit``, ``validate``, ``test`` or ``predict``) inside its
    constructor.

    Args:
        args: Optional arguments to parse instead of ``sys.argv[1:]`` (for
            example a list of strings or a config dict, as accepted by
            ``LightningCLI``). ``None`` keeps the default behaviour of reading
            the real command line.

    Returns:
        The constructed ``LightningCLI`` instance, so callers such as tests or
        notebooks can inspect ``cli.model``, ``cli.datamodule`` or
        ``cli.trainer`` after the run.
    """
    # Do not call ``cli.trainer.fit(...)`` here: the subcommand selected on the
    # command line has already been executed by the constructor, so an explicit
    # call would launch an extra, unrequested run.
    return LightningCLI(ColorTransformerModel, ColorDataModule, args=args)
if __name__ == "__main__":
    # The CLI wiring lives in ``cli_main`` and is invoked only from this guard,
    # so importing this module (e.g. from tests) does not start a run.
    cli_main()