From 3386f54d7082d68fe9e5ff52b43ce46998b50694 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 13:06:42 +0000 Subject: [PATCH] fix autotp linear check Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e741575ed..8d5f8afa1 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -664,9 +664,9 @@ def __init__(self, module, config) -> None: if use_bias: concat_bias = torch.concat(bias_list, 0).contiguous() self.concat_linear.bias = nn.Parameter(concat_bias) - self.q_slice = self.q_proj.out_features - self.k_slice = self.q_slice + self.k_proj.out_features - self.v_slice = self.k_slice + self.v_proj.out_features + self.q_slice = self.q_proj.weight.shape[0] + self.k_slice = self.q_slice + self.k_proj.weight.shape[0] + self.v_slice = self.k_slice + self.v_proj.weight.shape[0] if self.module_device.type == "cpu": if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: self.mha_linear_add = LinearAdd(module.o_proj)