add check for fa op
POI-WX committed Jan 22, 2025
1 parent ee852e4 commit 60ec7fd
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions deeplink_ext/internevo_ops/_flash_attention_npu.py
@@ -25,14 +25,19 @@ def flash_attn_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
 
     seqlen_q = q.shape[1]
     seqlen_k = k.shape[1]
     head_num = q.shape[-2]
 
-    assert seqlen_q == seqlen_k
+    assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k."
     sparse_mode = 2 if seqlen_q > 2048 else 0
     seqlen = min(seqlen_q, 2048)
 
@@ -78,14 +83,21 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     head_num = q.shape[-2]
 
     cu_seqlens_q = cu_seqlens_q[1:].tolist()
     cu_seqlens_k = cu_seqlens_k[1:].tolist()
 
-    assert max_seqlen_q == max_seqlen_k
+    assert (
+        max_seqlen_q == max_seqlen_k
+    ), "Npu currently only supports max_seqlen_q = max_seqlen_k."
     sparse_mode = 2 if max_seqlen_q > 2048 else 0
     max_seqlen = min(max_seqlen_q, 2048)
 
@@ -126,6 +138,11 @@ def flash_attn_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = qkv.shape[-1] ** (-0.5)
     q = qkv[:, :, 0]
@@ -175,6 +192,11 @@ def flash_attn_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     k = kv[:, :, 0]
@@ -184,7 +206,7 @@ def flash_attn_kvpacked_func(
     seqlen_kv = kv.shape[1]
     head_num = q.shape[-2]
 
-    assert seqlen_q == seqlen_kv
+    assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv."
     sparse_mode = 2 if seqlen_q > 2048 else 0
     seqlen = min(seqlen_q, 2048)
 
@@ -226,6 +248,11 @@ def flash_attn_varlen_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = qkv.shape[-1] ** (-0.5)
     q = qkv[:, 0]
@@ -280,6 +307,11 @@ def flash_attn_varlen_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     k = kv[:, 0]
@@ -288,7 +320,9 @@ def flash_attn_varlen_kvpacked_func(
     cu_seqlens_q = cu_seqlens_q[1:].tolist()
     cu_seqlens_k = cu_seqlens_k[1:].tolist()
 
-    assert max_seqlen_q == max_seqlen_k
+    assert (
+        max_seqlen_q == max_seqlen_k
+    ), "Npu currently only supports max_seqlen_q = max_seqlen_k."
     sparse_mode = 2 if max_seqlen_q > 2048 else 0
     max_seqlen = min(max_seqlen_q, 2048)
 
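For context, a minimal sketch of how the new guards behave from a caller's point of view. This is illustrative only: it assumes torch plus an NPU build of deeplink_ext are installed, that flash_attn_func follows the standard flash-attn calling convention, and the tensor shapes are made up rather than taken from this commit.

import torch

from deeplink_ext.internevo_ops._flash_attention_npu import flash_attn_func

# Hypothetical inputs of shape (batch, seqlen, num_heads, head_dim).
q = torch.randn(1, 1024, 8, 64)
k = torch.randn(1, 1024, 8, 64)
v = torch.randn(1, 1024, 8, 64)

# A finite sliding window now fails fast with a clear message instead of the
# argument going unchecked.
try:
    flash_attn_func(q, k, v, window_size=(256, 0))
except AssertionError as e:
    print(e)  # Npu currently does not support sliding window attention

# ALiBi slopes are likewise rejected before any NPU kernel is invoked.
try:
    flash_attn_func(q, k, v, alibi_slopes=torch.ones(8))
except AssertionError as e:
    print(e)  # Npu currently does not support ALiBi.

# With the default window_size=(-1, -1) and alibi_slopes=None the checks pass;
# actually computing the output still requires an Ascend NPU device.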
