
Commit

wrap varlen launcher in attention_forward_pytorch_ref_impl
micmelesse committed Oct 16, 2024
1 parent b9b1f24 commit 52f6bdc
Showing 2 changed files with 93 additions and 43 deletions.
59 changes: 58 additions & 1 deletion flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -54,7 +54,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):

    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scores

def attention_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
    """Compute reference output and softmax_lse using PyTorch's built-in function"""

    # Ensure the layout is 'bhsd'
Expand Down Expand Up @@ -181,6 +181,63 @@ def attention_varlen_forward_pytorch_ref_impl(
)


def attention_forward_pytorch_ref_impl(
q,
k,
v,
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2
):
# compute reference
if layout == "thd":
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scores_ref,
) = attention_varlen_forward_pytorch_ref_impl(
q.clone(),
k.clone(),
v.clone(),
sm_scale,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
)
else:
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scores_ref,
) = attention_vanilla_forward_pytorch_ref_impl(
q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
)

return (
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scores_ref,
)


def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
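As an aside for readers (not part of the commit): a minimal sketch of how the unified entry point above can be driven for both layouts. The signature is taken from the diff; everything else is an illustrative assumption — that the flash_attn Triton AMD backend is importable as shown, that 'bhsd' inputs are shaped (batch, heads, seqlen, head_dim), that 'thd' inputs are packed as (total_tokens, heads, head_dim) with int32 cumulative sequence lengths, and that the pure-PyTorch reference accepts CPU fp32 tensors.

# Hypothetical usage sketch (not part of the commit); assumes the package is
# importable and that the pure-PyTorch reference runs on CPU tensors.
import torch

from flash_attn.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl

torch.manual_seed(0)
batch, heads, seqlen, head_dim = 2, 4, 8, 16

# Batched path: layout "bhsd"; the varlen arguments are passed but unused.
q = torch.randn(batch, heads, seqlen, head_dim)
k = torch.randn(batch, heads, seqlen, head_dim)
v = torch.randn(batch, heads, seqlen, head_dim)
o_ref, softmax_lse_ref, *_ = attention_forward_pytorch_ref_impl(
    q, k, v,
    sm_scale=head_dim ** -0.5,
    causal=False,
    layout="bhsd",
    cu_seqlens_q=None, cu_seqlens_k=None,
    max_seqlen_q=None, max_seqlen_k=None,
    use_exp2=False,
)

# Varlen path: layout "thd", two sequences packed along the token dimension.
seqlens = [3, 5]
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative lengths, assumed convention
total_tokens = sum(seqlens)
q_v = torch.randn(total_tokens, heads, head_dim)
k_v = torch.randn(total_tokens, heads, head_dim)
v_v = torch.randn(total_tokens, heads, head_dim)
o_ref_v, softmax_lse_ref_v, *_ = attention_forward_pytorch_ref_impl(
    q_v, k_v, v_v,
    sm_scale=head_dim ** -0.5,
    causal=False,
    layout="thd",
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    use_exp2=False,
)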
77 changes: 35 additions & 42 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -3,7 +3,7 @@

from .utils import MetaData, get_input_shapes, input_helper, varlen_input_helper
from .interface_torch import attention_prefill, attention_decode
from .fwd_ref import attention_forward_pytorch_ref_impl, attention_varlen_forward_pytorch_ref_impl, compute_alibi_tensor_ref
from .fwd_ref import attention_forward_pytorch_ref_impl, compute_alibi_tensor_ref
from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
@@ -387,39 +387,27 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor
    # call Triton's forward implementation directly
    o, softmax_lse_triton, exp_scores_triton, grid, head_size, philox_seed, philox_offset, _, _ = attention_prefill_forward_triton_impl(q, k, v, o, metadata)

    # compute reference
    if layout == "thd":
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_varlen_forward_pytorch_ref_impl(
            q.clone(),
            k.clone(),
            v.clone(),
            metadata.sm_scale,
            causal,
            layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            use_exp2,
        )
    else:
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_forward_pytorch_ref_impl(
            q.clone(), k.clone(), v.clone(), metadata.sm_scale, causal, layout, use_exp2
        )
    (
        o_ref,
        softmax_lse_ref,
        exp_scores_ref,
        softmax_ref,
        attention_shifted_scaled_scores_ref,
        attention_scores_ref,
    ) = attention_forward_pytorch_ref_impl(
        q.clone(),
        k.clone(),
        v.clone(),
        metadata.sm_scale,
        causal,
        layout,
        metadata.cu_seqlens_q,
        metadata.cu_seqlens_k,
        metadata.max_seqlens_q,
        metadata.max_seqlens_k,
        use_exp2
    )

    if DEBUG:
        # ref output
        print("attention_scores_ref:", attention_scores_ref, attention_scores_ref.shape)
Expand Down Expand Up @@ -511,12 +499,7 @@ def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, b
dtype = torch.float16
torch.manual_seed(20) # seed from test_op_bwd

if DEBUG_INPUT:
sm_scale = 1
else:
sm_scale = D_HEAD ** -0.5
alibi_slopes = None

if layout == "thd":
q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT)
else:
Expand All @@ -538,7 +521,17 @@ def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, b
attention_shifted_scaled_scores_ref,
attention_scores_ref,
) = attention_forward_pytorch_ref_impl(
q_ref, k_ref, v_ref, sm_scale, causal, layout, use_exp2
q_ref,
k_ref,
v_ref,
metadata.sm_scale,
causal,
layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
use_exp2
)
if DEBUG:
print()
Expand All @@ -564,7 +557,7 @@ def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, b
v_ref,
o_ref,
softmax_lse_ref,
sm_scale,
metadata.sm_scale,
causal,
layout,
use_exp2,
Expand All @@ -584,7 +577,7 @@ def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, b
dq,
dk,
dv,
sm_scale,
metadata.sm_scale,
alibi_slopes,
causal,
layout,
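The updated call sites above forward metadata.cu_seqlens_q/k and metadata.max_seqlens_q/k (presumably built by varlen_input_helper in the "thd" path) straight into the reference. As a rough illustration of what that metadata looks like — a hypothetical helper following the common flash-attention varlen convention, not code from this repository — cumulative sequence lengths are int32 prefix sums of the per-sequence lengths with a leading zero:

# Hypothetical helper (not part of the commit): build varlen metadata from
# per-sequence lengths, assuming the usual flash-attention convention.
import torch

def build_varlen_metadata(seqlens_q, seqlens_k, device="cpu"):
    # Prefix sums with a leading zero give the start offset of each sequence.
    cu_seqlens_q = torch.tensor([0] + list(seqlens_q), dtype=torch.int32, device=device).cumsum(0, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0] + list(seqlens_k), dtype=torch.int32, device=device).cumsum(0, dtype=torch.int32)
    return cu_seqlens_q, cu_seqlens_k, max(seqlens_q), max(seqlens_k)

# Example: two sequences of lengths 3 and 5.
cu_q, cu_k, max_q, max_k = build_varlen_metadata([3, 5], [3, 5])
print(cu_q)  # tensor([0, 3, 8], dtype=torch.int32)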
