Commit

fix usage without pkv

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 12, 2024
1 parent 422134f commit 36884cb
Showing 1 changed file with 44 additions and 17 deletions.
optimum/exporters/ipex/modeling_utils.py (61 changes: 44 additions & 17 deletions)
@@ -30,6 +30,7 @@
 logger = logging.getLogger(__name__)
 
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
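The new constant gates the paged flash varlen kernel on the installed IPEX version; varlen_attn below only takes that path on IPEX >= 2.5.0. A minimal, self-contained sketch of the gating pattern, with packaging used as a hypothetical stand-in for the repo's is_ipex_version helper:

from packaging import version

_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"

def ipex_supports_flash_varlen(installed_ipex_version: str) -> bool:
    # Stand-in for is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN).
    return version.parse(installed_ipex_version) >= version.parse(
        _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN
    )

print(ipex_supports_flash_varlen("2.4.0"))  # False: use the varlen_attention fallback
print(ipex_supports_flash_varlen("2.5.0"))  # True: flash_attn_varlen_func is available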
@@ -588,6 +589,44 @@ def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output
 
+    def varlen_attn(self, query, key, value, past_key_value, input_lens):
+        # prefill, remove padding
+        attn_output = torch.empty_like(query)
+        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+        if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            PagedAttention.flash_attn_varlen_func(
+                attn_output,
+                query,
+                key,
+                value,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                1.0 / math.sqrt(self.head_dim),
+                True,
+                past_key_value.block_tables,
+                None,
+            )
+        else:
+            varlen_attention(
+                query.contiguous() if query.device.type == "xpu" else query,
+                key.contiguous() if key.device.type == "xpu" else key,
+                value.contiguous() if value.device.type == "xpu" else value,
+                attn_output,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                0.0,
+                1.0 / math.sqrt(self.head_dim),
+                False,
+                True,
+                False,
+                None,
+            )
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
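The varlen_attn helper added above dispatches between two kernels: PagedAttention.flash_attn_varlen_func needs both a paged KV cache (for block_tables) and IPEX >= 2.5.0, while the varlen_attention fallback covers every other case, including the previously broken call path with no past_key_value. Both kernels take seq_len_tensor, a cumulative-offset vector (cu_seqlens) over the packed, padding-free batch. A standalone sketch of what that expression computes, using plain PyTorch and hypothetical lengths:

import torch

# Per-sequence token counts for a batch of three prompts (hypothetical values).
input_lens = torch.tensor([5, 2, 7], dtype=torch.int32)

# Same construction as in varlen_attn above: prepend 0, then take the running
# sum, yielding each sequence's start offset in the packed buffer plus the
# total token count at the end.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

print(seq_len_tensor)  # tensor([ 0,  5,  7, 14], dtype=torch.int32)
# Sequence i occupies rows seq_len_tensor[i] : seq_len_tensor[i + 1]
# of the packed query/key/value tensors.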
@@ -609,27 +648,15 @@ def forward(
 
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
+        else:
+            key_cache, value_cache = key, value
 
-        attn_output = torch.empty_like(query)
         if past_len == 0:
-            # prefill, remove padding
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            PagedAttention.flash_attn_varlen_func(
-                attn_output,
-                query,
-                key_cache,
-                value_cache,
-                seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
-                input_lens.max(),
-                1.0 / math.sqrt(self.head_dim),
-                True,
-                past_key_value.block_tables,
-                None,
-            )
+            # prefill
+            attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
         else:
             # decode
+            attn_output = torch.empty_like(query)
             PagedAttention.single_query_cached_kv_attention(
                 attn_output,
                 query,
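For context, the failure this commit fixes is prefill without a KV cache: the old forward unconditionally called PagedAttention.flash_attn_varlen_func with past_key_value.block_tables, which cannot work when past_key_value is None, and key_cache/value_cache were never assigned in that case. A hypothetical reproduction via optimum-intel's IPEXModelForCausalLM (the model id and generation flags are illustrative, not taken from the commit):

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # any causal LM supported by the IPEX patching (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
# use_cache=False means no paged KV cache is allocated, so prefill must go
# through the varlen_attention fallback instead of the paged flash kernel.
outputs = model.generate(**inputs, max_new_tokens=8, use_cache=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))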
