From 810402dea7e3fb6ca5530bcee77229bd6629009c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 21 Feb 2025 11:32:50 +0000 Subject: [PATCH 1/2] Feat (brevitas_examples/sdxl): load vae checkpoint --- .../stable_diffusion/main.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index fef5900f2..6a3dd6915 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -586,6 +586,7 @@ def sdpa_zp_stats_type(): layerwise=True, blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None, add_mul_node=True): + for (inp_args, inp_kwargs) in vae_calibration: input_args = tuple([ input_arg.cpu() if isinstance(input_arg, torch.Tensor) else input_arg @@ -594,6 +595,8 @@ def sdpa_zp_stats_type(): k: (v.cpu() if isinstance(v, torch.Tensor) else v) for (k, v) in input_kwargs.items()} pipe.vae.decode(*inp_args, **inp_kwargs) + if args.dry_run or args.vae_load_checkpoint is not None: + break quantizers = generate_quantizers( dtype=dtype, @@ -646,7 +649,15 @@ def sdpa_zp_stats_type(): k: (v.cuda() if isinstance(v, torch.Tensor) else v) for (k, v) in vae_calibration[0][1].items()} pipe.vae.decode(*input_args, **input_kwargs) - if needs_calibration: + + if args.vae_load_checkpoint is not None: + with load_quant_model_mode(pipe.unet): + pipe = pipe.to('cpu') + print(f"Loading checkpoint: {args.vae_load_checkpoint}... 
", end="") + pipe.vae.load_state_dict(torch.load(args.vae_load_checkpoint, map_location='cpu')) + print(f"Checkpoint loaded!") + pipe = pipe.to(args.device) + if needs_calibration and not (args.dry_run or args.vae_load_checkpoint is not None): print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.vae): for (inp_args, inp_kwargs) in vae_calibration: @@ -658,7 +669,7 @@ def sdpa_zp_stats_type(): for (k, v) in input_kwargs.items()} pipe.vae.decode(*inp_args, **inp_kwargs) - if args.vae_gptq: + if args.vae_gptq and not (args.dry_run or args.vae_load_checkpoint is not None): print("Applying GPTQ") with torch.no_grad(), gptq_mode(pipe.vae, create_weight_orig=False, @@ -673,7 +684,7 @@ def sdpa_zp_stats_type(): k: (v.cuda() if isinstance(v, torch.Tensor) else v) for (k, v) in input_kwargs.items()} pipe.vae.decode(*inp_args, **inp_kwargs) - if args.vae_bias_correction: + if args.vae_bias_correction and not (args.dry_run or args.vae_load_checkpoint is not None): print("Applying Bias Correction") with torch.no_grad(), bias_correction_mode(pipe.vae): for inp_args, inp_kwargs in vae_calibration: @@ -830,6 +841,11 @@ def sdpa_zp_stats_type(): type=str, default=None, help='Path to checkpoint to load. If provided, PTQ techniques are skipped.') + parser.add_argument( + '--vae-load-checkpoint', + type=str, + default=None, + help='Path to checkpoint vae to load. 
If provided, PTQ techniques are skipped.') parser.add_argument( '--path-to-latents', type=str, From 663a079b798be68517a0bce3cc9cc1e922f42fba Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 26 Feb 2025 10:29:56 +0000 Subject: [PATCH 2/2] Fix (brevitas_examples/sdxl): enter load_quant_model_mode on pipe.vae, not pipe.unet, when loading the VAE checkpoint --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 6a3dd6915..e2b640495 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -651,7 +651,7 @@ def sdpa_zp_stats_type(): pipe.vae.decode(*input_args, **input_kwargs) if args.vae_load_checkpoint is not None: - with load_quant_model_mode(pipe.unet): + with load_quant_model_mode(pipe.vae): pipe = pipe.to('cpu') print(f"Loading checkpoint: {args.vae_load_checkpoint}... ", end="") pipe.vae.load_state_dict(torch.load(args.vae_load_checkpoint, map_location='cpu'))