diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 039d5201b2..91f406d3c5 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -591,9 +591,9 @@ def postprocess_attention_output(self, attn_output):
 
     def varlen_attn(self, query, key, value, past_key_value, input_lens):
         # prefill, remove padding
-        attn_output = torch.empty_like(query)
-        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
         if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
@@ -609,21 +609,12 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens):
                 None,
             )
         else:
-            varlen_attention(
-                query.contiguous() if query.device.type == "xpu" else query,
-                key.contiguous() if key.device.type == "xpu" else key,
-                value.contiguous() if value.device.type == "xpu" else value,
-                attn_output,
-                seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
-                input_lens.max(),
-                0.0,
-                1.0 / math.sqrt(self.head_dim),
-                False,
-                True,
-                False,
-                None,
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=None,
+                is_causal=True,
             )
         return attn_output
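
For context, a minimal standalone sketch of what the new fallback branch computes: when the IPEX flash varlen kernel is unavailable, prefill attention is now routed through PyTorch's built-in torch.nn.functional.scaled_dot_product_attention instead of the IPEX varlen_attention op. The tensor shapes below are illustrative assumptions and not taken from the patch; it only demonstrates the SDPA call used in the new else branch.

    # Sketch only, not part of the patch. Shapes are assumed for illustration.
    import torch

    batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    value = torch.randn(batch, num_heads, seq_len, head_dim)

    # Causal prefill attention; SDPA applies the 1/sqrt(head_dim) scaling internally,
    # which is why the explicit scale argument from the old varlen_attention call is gone.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, is_causal=True
    )
    print(attn_output.shape)  # torch.Size([1, 8, 16, 64])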