From 8515e51fbe5b6dce7060d6a383f2520465cbc2ea Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 27 Nov 2024 15:27:36 +0800
Subject: [PATCH] fix glm4-9b overflow

---
 .../llm/src/ipex_llm/transformers/convert.py  |  6 ++
 .../ipex_llm/transformers/models/chatglm4.py  | 65 +++++++++++++++++++
 2 files changed, 71 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 7a1274d502f..f6b159c32d1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1477,6 +1477,12 @@ def _optimize_post(model, lightweight_bmm=False):
                 convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
                 convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
                 convert_forward(model, module.MLP, mlp_forward)
+
+                if model.config.num_layers == 40:
+                    # workaround glm4-9b fp16 overflow
+                    from ipex_llm.transformers.models.chatglm4 import chatglm4_block_forward
+                    convert_forward(model, module.GLMBlock, chatglm4_block_forward)
+
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 72ac00a1a4a..47d5ea1221b 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -363,3 +363,68 @@ def chatglm4_encoder_forward(
         hidden_states = self.final_layernorm(hidden_states)
 
     return hidden_states, presents, all_hidden_states, all_self_attentions
+
+
+def chatglm4_block_forward(
+    self,
+    hidden_states,
+    attention_mask,
+    rotary_pos_emb,
+    kv_cache=None,
+    use_cache=True,
+):
+    # hidden_states: [s, b, h]
+
+    # Layer norm at the beginning of the transformer layer.
+    layernorm_output = self.input_layernorm(hidden_states)
+    # Self attention.
+    attention_output, kv_cache = self.self_attention(
+        layernorm_output,
+        attention_mask,
+        rotary_pos_emb,
+        kv_cache=kv_cache,
+        use_cache=use_cache
+    )
+
+    # Residual connection.
+    if self.apply_residual_connection_post_layernorm:
+        residual = layernorm_output
+    else:
+        residual = hidden_states
+
+    layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout,
+                                                  training=self.training)
+    layernorm_input = residual + layernorm_input
+
+    # Layer norm post the self attention.
+    layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+    # ipex-llm changes start: workaround fp16 overflow
+    scale = 10
+    if self.layer_number == 39 and layernorm_output.device.type == 'xpu':
+        gate = self.mlp.gate_proj(layernorm_output)
+        up = self.mlp.up_proj(layernorm_output) / scale
+        down = self.mlp.activation_fn(gate) * up
+        mlp_output = self.mlp.dense_4h_to_h(down)
+    else:
+        # MLP.
+        mlp_output = self.mlp(layernorm_output)
+    # ipex-llm changes end
+
+    # Second residual connection.
+    if self.apply_residual_connection_post_layernorm:
+        residual = layernorm_output
+    else:
+        residual = layernorm_input
+
+    output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout,
+                                         training=self.training)
+
+    # ipex-llm changes start: workaround fp16 overflow
+    if self.layer_number == 39 and layernorm_output.device.type == 'xpu':
+        output = residual + output * scale
+    else:
+        output = residual + output
+    # ipex-llm changes end
+
+    return output, kv_cache
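
Note on the workaround (reviewer sketch, not part of the patch): because dense_4h_to_h is a
bias-free linear projection, dividing the up-projection by `scale` before the element-wise
product and multiplying the block output by `scale` after the projection is algebraically
equivalent to the original MLP, while keeping the intermediate activations about 10x smaller
and inside the fp16 range. Below is a minimal standalone check of that identity, with made-up
sizes, SiLU standing in for mlp.activation_fn, and a plain nn.Linear standing in for
mlp.dense_4h_to_h:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, intermediate, scale = 64, 256, 10
# Bias-free stand-in for self.mlp.dense_4h_to_h; linearity is what makes the trick exact.
dense_4h_to_h = torch.nn.Linear(intermediate, hidden, bias=False)

gate = torch.randn(4, intermediate)
up = torch.randn(4, intermediate) * 100.0  # moderately large; in the real model the
                                           # layer-40 intermediates can exceed the fp16 range

# Reference path: what self.mlp normally computes.
ref = dense_4h_to_h(F.silu(gate) * up)

# Scaled path used by the workaround: shrink `up` before the product,
# rescale the projected output afterwards.
scaled = dense_4h_to_h(F.silu(gate) * (up / scale)) * scale

print(torch.allclose(ref, scaled, rtol=1e-4, atol=1e-3))  # True up to float rounding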