diff --git a/src/python/training/train_ldm.py b/src/python/training/train_ldm.py index 9f7d844..32b127f 100644 --- a/src/python/training/train_ldm.py +++ b/src/python/training/train_ldm.py @@ -74,12 +74,10 @@ def main(args): model_type="diffusion", ) - # Load Autoencoder to produce the latent representations print(f"Loading Stage 1 from {args.stage1_uri}") stage1 = mlflow.pytorch.load_model(args.stage1_uri) stage1.eval() - # Create the diffusion model print("Creating model...") config = OmegaConf.load(args.config_file) diffusion = DiffusionModelUNet(**config["ldm"].get("params", dict()))