diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py
index 087980de926..3cefad81fd2 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py
@@ -133,6 +133,11 @@
     action="store_true",
     help="Use determined group to do quantization",
 )
+parser.add_argument(
+    "--use_mse_search",
+    action="store_true",
+    help="Enables mean squared error (MSE) search.",
+)
 # ============AUTOROUND configs==============
 parser.add_argument(
     "--lr",
@@ -261,6 +266,7 @@
         sym=True if args.scheme == "sym" else False,
         blocksize=args.blocksize,
         static_groups=args.static_groups,
+        use_mse_search=args.use_mse_search,
         group_size=args.group_size,
         n_samples=args.n_samples,
         seq_len=args.seq_len,
diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
index d85eb5b48e6..78df8ffcdee 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
@@ -636,6 +636,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
             percdamp=config.damp_percent,
             block_size=config.blocksize,
             static_groups=config.static_groups,
+            use_mse_search=config.use_mse_search,
         )
         if config.llm_int8_skip_modules != []:
             for module in config.llm_int8_skip_modules:
diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py
index 2314512db63..9e27c76baa0 100644
--- a/intel_extension_for_transformers/transformers/utils/config.py
+++ b/intel_extension_for_transformers/transformers/utils/config.py
@@ -879,6 +879,7 @@ def __init__(
         n_samples: int = 128,
         seq_len: int = 2048,
         static_groups: bool = False,
+        use_mse_search: bool = False,
         true_sequential: bool = False,
         layer_wise: bool = False,
         use_ggml: bool = False,
@@ -908,6 +909,7 @@ def __init__(
         self.damp_percent = damp_percent
         self.desc_act = desc_act
         self.static_groups = static_groups
+        self.use_mse_search = use_mse_search
         self.true_sequential = true_sequential
         self.layer_wise = layer_wise
         self.seq_len = seq_len
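
For reference, a minimal sketch of how the new option might be exercised once this patch is applied. Only the GPTQConfig fields touched above (use_mse_search, static_groups, blocksize, group_size, damp_percent, n_samples, seq_len) are confirmed by the diff; the import path and the AutoModelForCausalLM integration are assumptions based on the repository's usual weight-only quantization flow.

# Sketch only: everything outside the GPTQConfig fields shown in the diff is illustrative.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, GPTQConfig

woq_config = GPTQConfig(
    damp_percent=0.01,
    blocksize=128,
    group_size=128,
    n_samples=128,
    seq_len=2048,
    static_groups=False,
    use_mse_search=True,  # new option added by this patch (default: False)
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # hypothetical model id for illustration
    quantization_config=woq_config,
)

The same behavior should be reachable from the example script by passing the new --use_mse_search flag alongside the existing GPTQ options.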