diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index fef5900f2..e2b640495 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -586,6 +586,7 @@ def sdpa_zp_stats_type():
                 layerwise=True,
                 blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
                 add_mul_node=True):
+
             for (inp_args, inp_kwargs) in vae_calibration:
                 input_args = tuple([
                     input_arg.cpu() if isinstance(input_arg, torch.Tensor) else input_arg
@@ -594,6 +595,8 @@ def sdpa_zp_stats_type():
                     k: (v.cpu() if isinstance(v, torch.Tensor) else v)
                     for (k, v) in input_kwargs.items()}
                 pipe.vae.decode(*inp_args, **inp_kwargs)
+                if args.dry_run or args.vae_load_checkpoint is not None:
+                    break
 
     quantizers = generate_quantizers(
         dtype=dtype,
@@ -646,7 +649,15 @@ def sdpa_zp_stats_type():
                 k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                 for (k, v) in vae_calibration[0][1].items()}
             pipe.vae.decode(*input_args, **input_kwargs)
-        if needs_calibration:
+
+        if args.vae_load_checkpoint is not None:
+            with load_quant_model_mode(pipe.vae):
+                pipe = pipe.to('cpu')
+                print(f"Loading checkpoint: {args.vae_load_checkpoint}... ", end="")
+                pipe.vae.load_state_dict(torch.load(args.vae_load_checkpoint, map_location='cpu'))
+                print("Checkpoint loaded!")
+            pipe = pipe.to(args.device)
+        if needs_calibration and not (args.dry_run or args.vae_load_checkpoint is not None):
             print("Applying activation calibration")
             with torch.no_grad(), calibration_mode(pipe.vae):
                 for (inp_args, inp_kwargs) in vae_calibration:
@@ -658,7 +669,7 @@ def sdpa_zp_stats_type():
                         for (k, v) in input_kwargs.items()}
                     pipe.vae.decode(*inp_args, **inp_kwargs)
 
-        if args.vae_gptq:
+        if args.vae_gptq and not (args.dry_run or args.vae_load_checkpoint is not None):
             print("Applying GPTQ")
             with torch.no_grad(), gptq_mode(pipe.vae,
                                             create_weight_orig=False,
@@ -673,7 +684,7 @@ def sdpa_zp_stats_type():
                         k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                         for (k, v) in input_kwargs.items()}
                     pipe.vae.decode(*inp_args, **inp_kwargs)
-        if args.vae_bias_correction:
+        if args.vae_bias_correction and not (args.dry_run or args.vae_load_checkpoint is not None):
             print("Applying Bias Correction")
             with torch.no_grad(), bias_correction_mode(pipe.vae):
                 for inp_args, inp_kwargs in vae_calibration:
@@ -830,6 +841,11 @@ def sdpa_zp_stats_type():
         type=str,
         default=None,
         help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
+    parser.add_argument(
+        '--vae-load-checkpoint',
+        type=str,
+        default=None,
+        help='Path to VAE checkpoint to load. If provided, PTQ techniques are skipped.')
     parser.add_argument(
         '--path-to-latents',
         type=str,
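For reference, the control flow this diff adds for the VAE mirrors the existing `--load-checkpoint` handling of the UNet: when a quantized VAE state dict is supplied (or `--dry-run` is set), the PTQ passes over the calibration data are skipped and the checkpoint is loaded under `load_quant_model_mode` instead. Below is a minimal sketch of that gating pattern, not the example's actual code; `vae`, `checkpoint_path`, `dry_run` and `run_ptq` are hypothetical stand-ins, and the import path for `load_quant_model_mode` is assumed to match what `main.py` already uses.

```python
# Sketch of the "load checkpoint OR run PTQ" gating introduced by the diff.
# All names except load_quant_model_mode and torch are hypothetical.
import torch

from brevitas.graph.calibrate import load_quant_model_mode  # assumed import path


def restore_or_calibrate(vae, checkpoint_path=None, dry_run=False, run_ptq=None):
    # PTQ (activation calibration, GPTQ, bias correction) is skipped whenever a
    # checkpoint is given or a dry run is requested.
    skip_ptq = dry_run or checkpoint_path is not None
    if checkpoint_path is not None:
        # The context manager puts the quantized layers in a state where the
        # quantized state dict can be loaded back.
        with load_quant_model_mode(vae):
            vae.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    if not skip_ptq and run_ptq is not None:
        run_ptq(vae)
    return vae
```

In the diff itself the same gating is expressed inline by appending `and not (args.dry_run or args.vae_load_checkpoint is not None)` to each PTQ condition rather than through a helper like this.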