diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index 481505502ca..81bf61879e2 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -532,7 +532,7 @@ def default_calib_func(model): "autoround_args": { "n_samples": config.nsamples, "seqlen": config.calib_len, - "iters": config.iters, + "iters": config.calib_iters, "scale_dtype": config.scale_dtype, "enable_quanted_input": not config.disable_quanted_input, "lr": config.lr, diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 47a922f30cd..65eb4158702 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -1065,7 +1065,7 @@ def __init__( minmax_lr: float = None, disable_quanted_input: bool = False, nsamples: int = 512, - iters: int = 200, + iters: int = None, use_ggml: bool = False, use_neural_speed: bool = False, llm_int8_skip_modules=None, @@ -1091,7 +1091,6 @@ def __init__( self.lr = lr self.minmax_lr = minmax_lr self.disable_quanted_input = disable_quanted_input - self.iters = iters self.llm_int8_skip_modules = ( llm_int8_skip_modules if llm_int8_skip_modules else [] ) @@ -1101,7 +1100,14 @@ def __init__( self.calib_dataloader = kwargs.get("calib_dataloader", None) self.calib_len = kwargs.get("calib_len", 2048) self.calib_func = kwargs.get("calib_func", None) - self.calib_iters = kwargs.get("calib_iters", 100) + calib_iters = kwargs.get("calib_iters", None) + if iters is not None: + self.calib_iters = iters + if calib_iters is not None: + logger.info("cannot be set simultaneously for 'iters' and 'calib_iters', " + "we will use 'iters' as calibration iterations!") + else: + self.calib_iters = 200 if calib_iters is None else calib_iters self.scheme = "sym" if self.sym else "asym" if isinstance(compute_dtype, torch.dtype): self.compute_dtype = convert_dtype_torch2str(compute_dtype)