Skip to content

Commit

Permalink
add XPU LinearAddAdd op (#1017)
Browse files Browse the repository at this point in the history
Signed-off-by: Liu, Kaixuan <[email protected]>
  • Loading branch information
kaixuanliu authored Nov 22, 2024
1 parent a5c48a8 commit 388265f
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def forward(
return hidden_states


class XPUlinearAddAdd(torch.nn.Module):
def __init__(self, module: torch.nn.Module):
super().__init__()
self.weight = module.weight.transpose(0, 1).contiguous()
self.bias = module.bias

def forward(self, x, y, z):
if self.bias is not None:
x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0)
x += z
else:
x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0)
return x


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def _ipex_rms_layer_norm_forward(self, hidden_states):
return rms_norm(hidden_states, self.weight, self.variance_epsilon)
Expand Down Expand Up @@ -703,7 +718,10 @@ def __init__(self, module, config) -> None:
elif self.module_device == "xpu":
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
if self.module_device == "cpu":
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
elif self.module_device == "xpu":
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)

def forward(
self,
Expand Down

0 comments on commit 388265f

Please sign in to comment.