diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py
index 5700d35dc..4cc3c0aba 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_ref.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -128,6 +128,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox
         print("o:", o, o.shape)
     return o, softmax_lse, sd_mask, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
 
+
 def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
     """Compute reference output and softmax_lse using PyTorch's built-in function"""
 
@@ -193,7 +194,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout
     if layout == "bshd":
         o = o.transpose(1, 2)
 
-    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
+    return o, softmax_lse, exp_scores, None, None, None, None
 
 
 def attention_varlen_forward_pytorch_ref_impl(
@@ -222,9 +223,11 @@
 
     # Pre-allocate outputs
     total_L_q = q.shape[0]
+    total_L_k = k.shape[0]
 
     o = torch.empty((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device)
     softmax_lse = torch.empty((total_L_q, nheads_q), dtype=torch.float32, device=q.device)
+    sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device)
 
     # Compute group_size for MQA/GQA handling
     group_size = nheads_q // nheads_k
@@ -275,12 +278,12 @@
         (
             o_i,
             softmax_lse_i,
-            exp_scores_i,
-            softmax_i,
-            attention_shifted_scaled_scores_i,
-            attention_scaled_scores_i,
-            attention_scores_i,
-        ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, use_exp2)
+            sd_mask_i,
+            _,
+            _,
+            _,
+            _,
+        ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2)
 
         # Reshape outputs back to original dimensions
         if group_size != 1:
@@ -298,15 +301,20 @@
         # Convert back to 'thd' layout and float16
         o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, nheads_q, head_dim]
         softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q]
+        sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i]
+
+        # print("sd_mask_i: ", sd_mask_i)
 
         # Place outputs in pre-allocated tensors
         o[start_q:end_q, :, :] = o_i
         softmax_lse[start_q:end_q, :] = softmax_lse_i
+
+        sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i
 
     return (
         o,
         softmax_lse,
-        None,
+        sd_mask,
         None,
         None,
         None,
@@ -354,11 +362,11 @@
         (
             o_ref,
             softmax_lse_ref,
-            exp_scores_ref,
-            softmax_ref,
-            attention_shifted_scaled_scores_ref,
-            attention_scaled_scores_ref,
-            attention_scores_ref,
+            sd_mask_ref,
+            _,
+            _,
+            _,
+            _,
         ) = attention_varlen_forward_pytorch_ref_impl(
             q.clone(),
             k.clone(),
@@ -379,11 +387,11 @@
         (
             o_ref,
             softmax_lse_ref,
-            exp_scores_ref,
-            softmax_ref,
-            attention_shifted_scaled_scores_ref,
-            attention_scaled_scores_ref,
-            attention_scores_ref,
+            sd_mask_ref,
+            _,
+            _,
+            _,
+            _,
         ) = attention_vanilla_forward_pytorch_ref_impl(q.clone(),
                                                        k.clone(),
                                                        v.clone(),
@@ -400,16 +408,16 @@
         print("attention_forward_pytorch_ref_impl outputs")
         print("o_ref:", o_ref, o_ref.shape)
         print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
-        print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None)
+        print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape if sd_mask_ref is not None else None)
 
     return (
         o_ref,
         softmax_lse_ref,
-        exp_scores_ref,
-        softmax_ref,
-        attention_shifted_scaled_scores_ref,
-        attention_scaled_scores_ref,
-        attention_scores_ref,
+        sd_mask_ref,
+        _,
+        _,
+        _,
+        _,
     )
 
 
diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index b1c0ae5d5..dc31b7b71 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -286,9 +286,6 @@ def varlen_fwd(
         print("window_size_right:", window_size_right)
         print("gen_:", gen_)
 
-    if dropout_p != 0.0:
-        raise ValueError("dropout is not supported on AMD's Triton Backend yet")
-
     if o is None:
         o = torch.empty_like(q)
 
@@ -309,6 +306,9 @@
 
     if dropout_p > 0.0:
         metadata.need_dropout(dropout_p, return_softmax)
+        rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensor uses the underlying data and does not cast
+    else:
+        rng_state = None
 
     # Check arguments
     metadata.check_args(q, k, v, o)
@@ -320,7 +320,7 @@
             print("Using reference implementation")
         (output,
             softmax_lse,
-            exp_scores,
+            sd_mask,
             _,
             _,
             _,
@@ -335,6 +335,9 @@
                 metadata.cu_seqlens_k,
                 metadata.max_seqlens_q,
                 metadata.max_seqlens_k,
+                metadata.dropout_p,
+                metadata.philox_seed,
+                metadata.philox_offset,
                 metadata.use_exp2)
         o.copy_(output)
     else:
@@ -342,7 +345,7 @@
             print("Using Triton implementation")
         (_,
             softmax_lse,
-            exp_scores,
+            sd_mask,
             _,
             _,
             _,
@@ -356,23 +359,25 @@
             metadata.sm_scale,
             metadata.alibi_slopes,
             metadata.causal,
-            metadata.bias,
-            metadata.dropout_p,
+            metadata.bias,
             metadata.layout,
             metadata.cu_seqlens_q,
             metadata.cu_seqlens_k,
             metadata.max_seqlens_q,
-            metadata.max_seqlens_k,
+            metadata.max_seqlens_k,
+            metadata.dropout_p,
+            metadata.philox_seed,
+            metadata.philox_offset,
             metadata.return_scores,
             metadata.use_exp2)
 
     if DEBUG:
         print("varlen_fwd outputs")
         print("o:", o, o.shape)
         print("softmax_lse:", softmax_lse, softmax_lse.shape)
-        print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None )
+        print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None )
 
-    return o, softmax_lse, exp_scores, None
+    return o, softmax_lse, sd_mask, rng_state
 
 def varlen_bwd(
     dout,
@@ -426,9 +431,10 @@ def varlen_bwd(
         print("gen_:", gen_)
         print("rng_state:", rng_state)
 
-    if dropout_p != 0.0:
-        raise ValueError("dropout is not supported on AMD yet")
-
+    if dropout_p > 0.0:
+        philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
+    else:
+        philox_seed, philox_offset = None, None
     if USE_REF:
         if DEBUG:
             print("Using reference implementation")
@@ -446,6 +452,9 @@
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
+            dropout_p,
+            philox_seed,
+            philox_offset,
             False,
         )
         dq.copy_(dq_ref)
@@ -473,6 +482,9 @@
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
+            dropout_p,
+            philox_seed,
+            philox_offset,
             False,
         )
         delta = delta_triton
diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py
index a68c01d15..424c4b81e 100755
--- a/tests/test_flash_attn_triton_amd.py
+++ b/tests/test_flash_attn_triton_amd.py
@@ -593,9 +593,6 @@ def get_dropout_fraction(
 @pytest.mark.parametrize("dropout_p", [0.0])
 def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
     if USE_TRITON_ROCM:
-        if dropout_p != 0.0:
-            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
-
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
 
@@ -753,9 +750,6 @@ def test_flash_attn_varlen_qkvpacked(
     seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
 ):
     if USE_TRITON_ROCM:
-        if dropout_p != 0.0:
-            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
-
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
@@ -891,8 +885,8 @@
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("dtype", [torch.float16])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 # @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("deterministic", [False])
@@ -951,8 +945,8 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 1
-    nheads = 1 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    batch_size = 4
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
@@ -1222,6 +1216,7 @@
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        # (4, 4),
         (1, 147),
         (113, 203),
         (128, 217),
@@ -1237,16 +1232,13 @@
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 # @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-@pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('dropout_p', [0.17])
 # @pytest.mark.parametrize("softcap", [0.0, 50.0])
 @pytest.mark.parametrize("softcap", [0.0])
 def test_flash_attn_varlen_output(
     seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if USE_TRITON_ROCM:
-        if dropout_p != 0.0:
-            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
-
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
 
@@ -1526,7 +1518,7 @@
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
 
     if dropout_p > 0.0:
-        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
+        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
         # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
         if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
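
Note (reviewer sketch, not part of the patch): two pieces of plumbing in this diff are worth seeing in isolation. First, varlen_fwd packs philox_seed/philox_offset into rng_state with torch.as_tensor and varlen_bwd reads them back with .item(), so forward and backward can regenerate the same dropout mask. Second, fwd_ref.py scatters each per-sequence sd_mask_i into a zero-padded (batch, nheads_q, max_seqlen_q, max_seqlen_k) buffer using cu_seqlens. The snippet below is a minimal sketch under stated assumptions: it uses PyTorch's own RNG as a stand-in for the Triton philox kernel, and dropout_mask_from_rng_state / scatter_sd_mask are hypothetical helper names that do not exist in the repository.

import torch

def pack_rng_state(philox_seed, philox_offset):
    # varlen_fwd side: torch.as_tensor keeps the two ints as-is (no copy or cast),
    # giving a 2-element tensor [seed, offset] that can be returned to the caller.
    return torch.as_tensor([philox_seed, philox_offset])

def unpack_rng_state(rng_state):
    # varlen_bwd side: recover plain Python ints for the kernel launch.
    return rng_state[0].item(), rng_state[1].item()

def dropout_mask_from_rng_state(rng_state, shape, dropout_p, device="cpu"):
    # Hypothetical stand-in for the Triton philox dropout: any deterministic
    # function of (seed, offset) works for this sketch, as long as forward and
    # backward call the same one and therefore rebuild the identical keep-mask.
    seed, offset = unpack_rng_state(rng_state)
    gen = torch.Generator(device=device).manual_seed(seed + offset)
    return torch.rand(shape, generator=gen, device=device) >= dropout_p

def scatter_sd_mask(sd_mask_list, cu_seqlens_q, cu_seqlens_k,
                    max_seqlen_q, max_seqlen_k, nheads_q):
    # Mirrors the fwd_ref.py change: per-sequence masks of shape
    # [nheads_q, L_q_i, L_k_i] are written into a zero-padded
    # [batch, nheads_q, max_seqlen_q, max_seqlen_k] buffer.
    batch_size = len(sd_mask_list)
    sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k))
    for i, sd_mask_i in enumerate(sd_mask_list):
        seqlen_q = int(cu_seqlens_q[i + 1] - cu_seqlens_q[i])
        seqlen_k = int(cu_seqlens_k[i + 1] - cu_seqlens_k[i])
        sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i
    return sd_mask

# Round trip: forward saves rng_state, backward rebuilds the identical mask.
state = pack_rng_state(philox_seed=1234, philox_offset=0)
mask_fwd = dropout_mask_from_rng_state(state, (2, 3, 8, 8), dropout_p=0.17)
mask_bwd = dropout_mask_from_rng_state(state, (2, 3, 8, 8), dropout_p=0.17)
assert torch.equal(mask_fwd, mask_bwd)

# Varlen padding: two sequences of different lengths land in one padded buffer.
masks = [torch.ones(3, 4, 5), torch.ones(3, 2, 7)]
padded = scatter_sd_mask(masks, cu_seqlens_q=[0, 4, 6], cu_seqlens_k=[0, 5, 12],
                         max_seqlen_q=4, max_seqlen_k=7, nheads_q=3)
assert padded.shape == (2, 3, 4, 7)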