diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 2314512db63..3f328c2ff33 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -831,7 +831,10 @@ def __init__( self.double_quant_bits = double_quant_bits self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size - self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]) + # "transformer.output_layer" for chatglm series model. + # "embed_out" for dolly v2 series model. + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", + ["lm_head", "transformer.output_layer", "embed_out"]) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed @@ -911,7 +914,8 @@ def __init__( self.true_sequential = true_sequential self.layer_wise = layer_wise self.seq_len = seq_len - self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", + ["lm_head", "transformer.output_layer", "embed_out"]) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed @@ -1009,7 +1013,8 @@ def __init__( self.seq_len = seq_len self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype - self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", + ["lm_head", "transformer.output_layer", "embed_out"]) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed @@ -1078,7 +1083,8 @@ def __init__( self.seq_len = seq_len self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype - self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", + ["lm_head", "transformer.output_layer", "embed_out"]) self.use_ggml = use_ggml self.use_neural_speed = use_neural_speed self.device = kwargs.get("device", "auto") @@ -1154,7 +1160,8 @@ def __init__( self.iters = iters self.seq_len = seq_len self.quant_lm_head = quant_lm_head - self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"]) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", + ["lm_head", "transformer.output_layer", "embed_out"]) if self.quant_lm_head: self.llm_int8_skip_modules = [] self.use_ggml = use_ggml