From 388265f7139d2920daed57f6da044e7e88841680 Mon Sep 17 00:00:00 2001
From: kaixuanliu
Date: Fri, 22 Nov 2024 14:05:53 +0800
Subject: [PATCH] add XPU LinearAddAdd op (#1017)

Signed-off-by: Liu, Kaixuan
---
 optimum/exporters/ipex/modeling_utils.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index b42966e3d..a892335ee 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -118,6 +118,21 @@ def forward(
         return hidden_states
 
 
+class XPUlinearAddAdd(torch.nn.Module):
+    def __init__(self, module: torch.nn.Module):
+        super().__init__()
+        self.weight = module.weight.transpose(0, 1).contiguous()
+        self.bias = module.bias
+
+    def forward(self, x, y, z):
+        if self.bias is not None:
+            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0)
+            x += z
+        else:
+            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0)
+        return x
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
 def _ipex_rms_layer_norm_forward(self, hidden_states):
     return rms_norm(hidden_states, self.weight, self.variance_epsilon)
@@ -703,7 +718,10 @@ def __init__(self, module, config) -> None:
         elif self.module_device == "xpu":
             self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
         if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
-            self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
+            if self.module_device == "cpu":
+                self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
+            elif self.module_device == "xpu":
+                self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
 
     def forward(
         self,
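
For reference, the fused XPU op above computes linear(x) + y + z in a single torch_ipex.mm_bias_resadd call (with z folded into the bias slot when the linear layer has no bias). A minimal pure-PyTorch sketch of the same math, assuming an illustrative LinearAddAddReference class and hypothetical shapes that are not part of the patch, can serve as a numerical sanity check against the fused kernel:

import torch

class LinearAddAddReference(torch.nn.Module):
    # Illustrative reference only: computes linear(x) + y + z without the fused IPEX kernel.
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        # Keep the nn.Linear weight layout; XPUlinearAddAdd instead stores
        # weight.transpose(0, 1).contiguous() to match what mm_bias_resadd expects.
        self.weight = module.weight
        self.bias = module.bias

    def forward(self, x, y, z):
        return torch.nn.functional.linear(x, self.weight, self.bias) + y + z

# Hypothetical shapes for a quick CPU check.
linear = torch.nn.Linear(16, 32)
x, y, z = torch.randn(4, 16), torch.randn(4, 32), torch.randn(4, 32)
ref = LinearAddAddReference(linear)
print(ref(x, y, z).shape)  # torch.Size([4, 32])

On an XPU device with intel_extension_for_pytorch installed, the output of XPUlinearAddAdd should match this reference up to floating-point tolerance.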