
Commit

use sdpa for no cache forward
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 12, 2024
1 parent 36884cb commit d061e69
Showing 1 changed file with 8 additions and 17 deletions: optimum/exporters/ipex/modeling_utils.py
@@ -591,9 +591,9 @@ def postprocess_attention_output(self, attn_output):
 
     def varlen_attn(self, query, key, value, past_key_value, input_lens):
         # prefill, remove padding
-        attn_output = torch.empty_like(query)
-        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
         if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
@@ -609,21 +609,12 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens):
                 None,
             )
         else:
-            varlen_attention(
-                query.contiguous() if query.device.type == "xpu" else query,
-                key.contiguous() if key.device.type == "xpu" else key,
-                value.contiguous() if value.device.type == "xpu" else value,
-                attn_output,
-                seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
-                input_lens.max(),
-                0.0,
-                1.0 / math.sqrt(self.head_dim),
-                False,
-                True,
-                False,
-                None,
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=None,
+                is_causal=True,
             )
         return attn_output

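The net effect of the change: when there is no KV cache (the prefill / no-cache forward path), attention is now computed with PyTorch's torch.nn.functional.scaled_dot_product_attention instead of IPEX's varlen_attention. The sketch below is not part of the patch; it only illustrates the SDPA call used in the new else-branch, with assumed [batch, num_heads, seq_len, head_dim] shapes for demonstration.

# Illustrative sketch only (not from the repository): the SDPA call the new
# else-branch relies on, with assumed shapes [batch, num_heads, seq_len, head_dim].
import torch

batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# is_causal=True applies a lower-triangular mask so each position attends only
# to itself and earlier positions; the scale defaults to 1 / sqrt(head_dim),
# which matches the 1.0 / math.sqrt(self.head_dim) scale that the removed
# varlen_attention call passed explicitly.
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, is_causal=True
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])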