diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index fdc7ea86b..039d5201b 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -30,6 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -588,6 +589,44 @@ def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output
 
+    def varlen_attn(self, query, key, value, past_key_value, input_lens):
+        # prefill, remove padding
+        attn_output = torch.empty_like(query)
+        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+        if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            PagedAttention.flash_attn_varlen_func(
+                attn_output,
+                query,
+                key,
+                value,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                1.0 / math.sqrt(self.head_dim),
+                True,
+                past_key_value.block_tables,
+                None,
+            )
+        else:
+            varlen_attention(
+                query.contiguous() if query.device.type == "xpu" else query,
+                key.contiguous() if key.device.type == "xpu" else key,
+                value.contiguous() if value.device.type == "xpu" else value,
+                attn_output,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                0.0,
+                1.0 / math.sqrt(self.head_dim),
+                False,
+                True,
+                False,
+                None,
+            )
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -609,27 +648,15 @@ def forward(
 
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
+        else:
+            key_cache, value_cache = key, value
 
-        attn_output = torch.empty_like(query)
         if past_len == 0:
-            # prefill, remove padding
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            PagedAttention.flash_attn_varlen_func(
-                attn_output,
-                query,
-                key_cache,
-                value_cache,
-                seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
-                input_lens.max(),
-                1.0 / math.sqrt(self.head_dim),
-                True,
-                past_key_value.block_tables,
-                None,
-            )
+            # prefill
+            attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
         else:
             # decode
+            attn_output = torch.empty_like(query)
             PagedAttention.single_query_cached_kv_attention(
                 attn_output,
                 query,
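
Note (illustration, not part of the patch): the new `varlen_attn` helper routes prefill to `PagedAttention.flash_attn_varlen_func` only when a paged KV cache is present and IPEX is at least 2.5.0, and otherwise falls back to `varlen_attention`. Both back ends consume the same packed, padding-free layout, where `seq_len_tensor` holds the cumulative start offset of each sequence in the batch. A minimal sketch of how those offsets come out of `input_lens`, assuming a hypothetical batch with prompt lengths 3, 5, and 2:

    import torch

    # Hypothetical per-sequence prompt lengths for a batch of 3.
    input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

    # Same construction as in varlen_attn: prepend 0, then take cumulative sums,
    # so entry i is the row where sequence i starts in the packed QKV tensors.
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

    # [0, 3, 8, 10]: sequence i occupies packed rows [offsets[i], offsets[i+1]).
    print(seq_len_tensor.tolist())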