From f9c021b4c1130dac4eb069ea3161aaf182449c2c Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 12 Dec 2024 16:51:14 +0000
Subject: [PATCH] revert shape for sdpa

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/modeling_utils.py | 18 +++++++++---------
 optimum/intel/ipex/modeling_base.py      |  2 ++
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index f84c1cd548..336b3871b2 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -181,7 +181,7 @@ def _llama_model_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -195,7 +195,7 @@ def _llama_model_forward(
     next_decoder_cache = () if use_cache else None
 
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
-    if past_key_values_length == 0:
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -298,7 +298,7 @@ def _falcon_model_forward(
         )
 
     if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
+        position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
@@ -310,7 +310,7 @@ def _falcon_model_forward(
 
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
-    if past_key_values_length == 0:
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -420,7 +420,7 @@ def _gpt2_model_forward(
     past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if position_ids is None:
         position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     if inputs_embeds is None:
         inputs_embeds = self.wte(input_ids)
@@ -437,7 +437,7 @@ def _gpt2_model_forward(
 
     hidden_states = self.drop(hidden_states)
 
-    if past_length == 0:
+    if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -636,9 +636,9 @@ def forward(
         # prefill
         if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
+                query.reshape(input_lens.shape[0], -1, query.shape[-1]),
+                key.reshape(input_lens.shape[0], -1, key.shape[-1]),
+                value.reshape(input_lens.shape[0], -1, value.shape[-1]),
                 attn_mask=None,
                 is_causal=True,
             )
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 8611bddd21..d8f830e519 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -276,6 +276,8 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        if self.add_patch and input_ids is not None and attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
 
     def _prepare_generation_config(
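
Note (illustration, not part of the patch): the snippet below sketches what the position_ids
change applied in the llama, falcon and gpt2 forwards above does, using a hypothetical 3x5
input_ids batch. unsqueeze(0) alone yields a (1, seq_len) tensor; adding repeat_interleave
along dim 0 expands it to (batch_size, seq_len), one row of positions per sequence in the batch.

import torch

# hypothetical padded batch: 3 sequences of 5 tokens each (for illustration only)
input_ids = torch.randint(0, 100, (3, 5))
past_key_values_length = 0
seq_length = input_ids.shape[-1]

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
# before this patch: position_ids.unsqueeze(0)      -> shape (1, 5)
# with this patch:   one row per batch element      -> shape (3, 5)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
assert position_ids.shape == (3, 5)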