revert shape for sdpa
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 12, 2024
1 parent 73a5ef7 commit f9c021b
Showing 2 changed files with 11 additions and 9 deletions.
18 changes: 9 additions & 9 deletions optimum/exporters/ipex/modeling_utils.py
@@ -181,7 +181,7 @@ def _llama_model_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
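For context, a minimal standalone sketch of what the new repeat_interleave call does (sizes are assumed for illustration, not taken from the repository): it tiles the 1-D position ids across the batch dimension so every sequence in the batch gets its own row of positions.

import torch

# Assumed sizes: expand 1-D position ids to (batch, seq_len) so downstream
# rotary embeddings see one row of positions per sequence.
batch_size, seq_len, past_len = 2, 4, 0
position_ids = torch.arange(past_len, seq_len + past_len, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).repeat_interleave(batch_size, 0)
print(position_ids.shape)  # torch.Size([2, 4])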
@@ -195,7 +195,7 @@ def _llama_model_forward(
     next_decoder_cache = () if use_cache else None

     position_embeddings = self.rotary_emb(hidden_states, position_ids)
-    if past_key_values_length == 0:
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -298,7 +298,7 @@ def _falcon_model_forward(
         )

     if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
+        position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
@@ -310,7 +310,7 @@ def _falcon_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)

-    if past_key_values_length == 0:
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -420,7 +420,7 @@ def _gpt2_model_forward(
     past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if position_ids is None:
         position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

     if inputs_embeds is None:
         inputs_embeds = self.wte(input_ids)
@@ -437,7 +437,7 @@ def _gpt2_model_forward(

     hidden_states = self.drop(hidden_states)

-    if past_length == 0:
+    if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -636,9 +636,9 @@ def forward(
         # prefill
         if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
+                query.reshape(input_lens.shape[0], -1, query.shape[-1]),
+                key.reshape(input_lens.shape[0], -1, key.shape[-1]),
+                value.reshape(input_lens.shape[0], -1, value.shape[-1]),
                 attn_mask=None,
                 is_causal=True,
             )
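The prefill fallback above is what the commit title "revert shape for sdpa" refers to: query/key/value are reshaped back to a batched layout before calling SDPA. A minimal sketch of that reshape, assuming the varlen path keeps the projections flattened as (batch * seq_len, num_heads, head_dim) with equal-length sequences (all shapes here are assumptions, not taken from the repository):

import torch
import torch.nn.functional as F

# Assumed shapes: flattened per-token projections from a padded, equal-length batch.
batch, seq_len, num_heads, head_dim = 2, 4, 8, 64
input_lens = torch.full((batch,), seq_len)
query = torch.randn(batch * seq_len, num_heads, head_dim)
key = torch.randn(batch * seq_len, num_heads, head_dim)
value = torch.randn(batch * seq_len, num_heads, head_dim)

# Restore a per-sample leading dimension, mirroring the change above, since
# F.scaled_dot_product_attention expects batched (..., L, E) inputs.
attn_output = F.scaled_dot_product_attention(
    query.reshape(input_lens.shape[0], -1, query.shape[-1]),
    key.reshape(input_lens.shape[0], -1, key.shape[-1]),
    value.reshape(input_lens.shape[0], -1, value.shape[-1]),
    attn_mask=None,
    is_causal=True,
)
print(attn_output.shape)  # torch.Size([2, 32, 64])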
2 changes: 2 additions & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -276,6 +276,8 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        if self.add_patch and input_ids is not None and attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

     def _prepare_generation_config(
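The new guard in modeling_base.py supplies a default mask when the caller omits one for a patched model. A minimal sketch of the fallback behavior (input values are made up for illustration):

import torch

# If no attention mask is provided, fall back to an all-ones mask with the
# same shape as input_ids, i.e. treat every token as non-padding.
input_ids = torch.tensor([[101, 202, 303]])
attention_mask = None
if attention_mask is None:
    attention_mask = torch.ones_like(input_ids)
print(attention_mask)  # tensor([[1, 1, 1]])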
