From a1689993003535dd81690b84f9238d60c756fd36 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 14 Oct 2024 21:10:46 -0500 Subject: [PATCH] Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save --- .../flash_attn_triton_amd/bwd_prefill.py | 262 +++++++++---- flash_attn/flash_attn_triton_amd/bwd_ref.py | 10 +- .../flash_attn_triton_amd/fwd_prefill.py | 4 +- flash_attn/flash_attn_triton_amd/fwd_ref.py | 29 +- .../flash_attn_triton_amd/interface_fa.py | 99 ++++- .../flash_attn_triton_amd/interface_torch.py | 5 +- flash_attn/flash_attn_triton_amd/test.py | 196 ++++++---- flash_attn/flash_attn_triton_amd/utils.py | 72 ++-- tests/test_flash_attn_triton.py | 348 ++++++------------ 9 files changed, 586 insertions(+), 439 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py index 8fc644e96..9d847b550 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -3,11 +3,12 @@ import triton.language as tl from .bwd_ref import attention_backward_pytorch_ref_impl +from .utils import get_shape_from_layout, get_strides_from_layout DEBUG = False @triton.jit -def _bwd_preprocess_use_o( +def _bwd_preprocess_use_o_old( Out, DO, Delta, @@ -36,6 +37,51 @@ def _bwd_preprocess_use_o( tl.store(Delta + off_m, delta) + +@triton.jit +def _bwd_preprocess_use_o( + Out, + DO, + Delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_doz, stride_doh, stride_dom, stride_dok, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + N_CTX_Q: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + + # Compute batch and head indices + batch_idx = pid_bh // H + head_idx = pid_bh % H + + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_d = tl.arange(0, BLOCK_DMODEL) + + # create masks + mask_m = off_m < N_CTX_Q + mask_d = off_d < ACTUAL_BLOCK_DMODEL + + # compute pointers using strides + out_ptrs = Out + batch_idx * stride_oz + head_idx * stride_oh + off_m[:, None] * stride_om + off_d[None, :] * stride_ok + do_ptrs = DO + batch_idx * stride_doz + head_idx * stride_doh + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok + + # load + o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + + # compute delta + delta = tl.sum(o * do, axis=1) + + # write-back delta + delta_ptrs = Delta + pid_bh * N_CTX_Q + off_m + tl.store(delta_ptrs, delta, mask=mask_m) + + @triton.jit def _bwd_preprocess_use_p( Q, # Pointer to queries @@ -385,7 +431,7 @@ def _bwd_kernel( else: dq_offset = DQ + off_z * stride_qz + off_h * stride_qh - # inner loop + # inner loop if SEQUENCE_PARALLEL: _bwd_kernel_one_col_block( Q, @@ -494,11 +540,31 @@ def _bwd_kernel( USE_EXP2=USE_EXP2, ) -# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. -def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o, BLOCK_M=64, BLOCK_N=64): - - DEBUG_INPUT=False +# NOTE: smaller blocks have lower accuracy. 
more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. +def attention_prefill_backward_triton_new_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + use_exp2: bool, + bwd_preprocessing_use_o: bool, + BLOCK_M=64, + BLOCK_N=64, +): if DEBUG: print() print("attention_prefill_backward_triton_new_impl") @@ -512,7 +578,6 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, print("dk:", dk, dk.shape if dk is not None else None) print("dv:", dv, dv.shape if dv is not None else None) print("sm_scale:", sm_scale) - print("head_size:", head_size) print("alibi_slopes:", alibi_slopes) print("layout:", layout) print("use_exp2:", use_exp2) @@ -520,35 +585,37 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, print("BLOCK_M:", BLOCK_M) print("BLOCK_N:", BLOCK_N) - # the kernel wants bhsd - if layout == "bshd": - print("Changing layout to bhsd!") - do = do.transpose(1, 2).contiguous() - q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() - o = o.transpose(1, 2).contiguous() - # TODO: does L/M need to be transposed. possible to use strides - elif layout == "bhsd": - pass + # make contigious + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + softmax_lse = softmax_lse.contiguous() + + # get strides and shape + if True: + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + stride_qz, stride_qh, stride_qm, stride_qk = q_strides + stride_kz, stride_kh, stride_kn, stride_kk = k_strides + stride_vz, stride_vh, stride_vn, stride_vk = v_strides + stride_oz, stride_oh, stride_om, stride_ok = o_strides + stride_dq_all = q.numel() + batch_headsize = batch * nheads_q else: - raise ValueError(f"Unknown layout {layout}") + batch_q, heads_q, seqlen_q, head_size_q = q.shape + batch_k, heads_k, seqlen_k, head_size_k = k.shape + batch_headsize = batch_q * heads_q + stride_dq_all = dq.numel() + stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3) + stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3) + stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3) sequence_parallel = False causal = False - batch_q, heads_q, N_CTX_Q, head_size_q = q.shape - batch_k, heads_k, N_CTX_K, head_size_k = k.shape - - assert (batch_q == batch_k) - assert (heads_q == heads_k) # just for now - assert (head_size_q == head_size_q == head_size) - - batch = batch_q - # divide up the problem - num_blocks_m = triton.cdiv(N_CTX_Q, BLOCK_M) - num_blocks_n = triton.cdiv(N_CTX_K, BLOCK_N) + num_blocks_m = triton.cdiv(seqlen_q, BLOCK_M) + num_blocks_n = triton.cdiv(seqlen_k, BLOCK_N) # get closest power of 2 over or equal to 32. 
padded_d_model = 1 << (head_size - 1).bit_length() @@ -563,9 +630,13 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, new_dq_shape = (replicas,) + q.shape if dq is None: dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = dq.contiguous() else: if dq is None: dq = torch.zeros_like(q, dtype=q.dtype) + else: + dq = dq.contiguous() # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros dq.zero_() @@ -575,12 +646,16 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, dk = torch.zeros_like(k) else: dk = torch.empty_like(k) + else: + dk = dk.contiguous() if dv is None: if True: dv = torch.zeros_like(v) else: dv = torch.empty_like(v) + else: + dv = dv.contiguous() # assert contigious assert do.is_contiguous() @@ -593,11 +668,6 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, assert dk.is_contiguous() assert dv.is_contiguous() - batch_headsize = batch * heads_q - stride_dq_all = dq.numel() - stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3) - stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3) - stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3) num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful num_stages = 1 @@ -605,20 +675,34 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, delta = torch.zeros_like(softmax_lse) else: delta = torch.empty_like(softmax_lse) - if bwd_preprocessing_use_o: - _bwd_preprocess_use_o[(batch_headsize * num_blocks_m,)]( - o, - do, - delta, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - N_CTX_Q=N_CTX_Q - ) + if False: + _bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)]( + o, + do, + delta, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + N_CTX_Q=seqlen_q + ) + else: + _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)]( + o, + do, + delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_oz, stride_oh, stride_om, stride_ok, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + N_CTX_Q=seqlen_q, + Z=batch, + H=nheads_q, + ) else: - _bwd_preprocess_use_p[(num_blocks_m, batch_headsize)]( + _bwd_preprocess_use_p[(num_blocks_m, batch_headsize)]( q, k, v, @@ -639,10 +723,10 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, stride_vh, stride_vn, stride_vk, - Z=batch_q, - H=heads_q, - N_CTX_Q=N_CTX_Q, - N_CTX_K=N_CTX_K, + Z=batch, + H=nheads_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL, @@ -650,8 +734,7 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, USE_EXP2=use_exp2, ) - - if False: + if DEBUG: print("_bwd_kernel inputs") print("do:", do, do.shape) print("q:", q, q.shape) @@ -667,12 +750,10 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk) print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk) print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk) - print("batch_q:", batch_q) - 
print("heads_q:",heads_q) - print("N_CTX_Q:",N_CTX_Q) - print("N_CTX_K:",N_CTX_K) - print("batch_q * head_size_q * N_CTX_Q:",batch_q * head_size_q * N_CTX_Q) - print("num_blocks_n * batch_q * head_size_q * N_CTX_Q:",num_blocks_n * batch_q * head_size_q * N_CTX_Q) + print("batch_q:", batch) + print("heads_q:",nheads_q) + print("seqlen_q:",seqlen_q) + print("seqlen_k:",seqlen_k) print("BLOCK_M:",BLOCK_M) print("BLOCK_N:",BLOCK_M) print("BLOCK_DMODEL:",BLOCK_DMODEL) @@ -699,10 +780,10 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, - batch_q, - heads_q, - N_CTX_Q, - N_CTX_K, + batch, + nheads_q, + seqlen_q, + seqlen_k, num_blocks_m, num_blocks_n, BLOCK_M=BLOCK_M, @@ -719,23 +800,53 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, if len(dq.shape) == 5: dq = dq.sum(dim=0) - # go back to original layout - if layout == "bshd": - print("Changing back to bshd!") - dq = dq.transpose(1, 2) - dk = dk.transpose(1, 2) - dv = dv.transpose(1, 2) - elif layout == "bhsd": - pass - else: - raise ValueError(f"Unknown layout {layout}") - return dq, dk, dv, delta, None, None -def attention_prefill_backward_triton_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o, use_new): +def attention_prefill_backward_triton_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + use_exp2: bool, + bwd_preprocessing_use_o: bool, + use_new, +): if use_new: - return attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o) + return attention_prefill_backward_triton_new_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + use_exp2, + bwd_preprocessing_use_o, + ) else: # test pytorch impl dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( @@ -757,4 +868,3 @@ def attention_prefill_backward_triton_impl(do, q, k, v, o, softmax_lse, dq, dk, dv = dv_ref return dq, dk, dv, delta_ref, None, None - diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 5576937d1..7ae61ee33 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -1,12 +1,14 @@ import torch import math -DEBUG=False +DEBUG = False def attention_backward_pytorch_ref_impl(do, q, k, v, o, softmax_lse, sm_scale, causal, layout, use_exp2, bwd_preprocessing_use_o): # ensure the layout is 'bhsd' if layout == "bshd": - print("Changing layout to bhsd!") + if DEBUG: + print() + print("Changing layout to bhsd!") do = do.transpose(1, 2).contiguous() q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() @@ -66,7 +68,9 @@ def attention_backward_pytorch_ref_impl(do, q, k, v, o, softmax_lse, sm_scale, c # go back to original layout if layout == "bshd": - print("Changing back to bshd!") + if DEBUG: + print() + print("Changing back to bshd!") dq = dq.transpose(1, 2) dk = dk.transpose(1, 2) dv = dv.transpose(1, 2) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py 
b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 8a6543bd7..02ae4eaae 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -486,8 +486,8 @@ def attention_prefill_forward_triton_impl(q, k, v, o, metadata): o = torch.empty_like(q, dtype=v.dtype) metadata.check_args(q, k, v, o) - batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata.layout) # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index c1536fcb8..4a6fb075a 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -1,12 +1,22 @@ import math import torch +DEBUG = False + def attention_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2): """compute reference output and softmax_lse using PyTorch's built-in function""" - # expects bhsd layout - if layout != "bhsd": - raise ValueError("bhsd is the only layout supported") + # ensure the layout is 'bhsd' + if layout == "bshd": + if DEBUG: + print("Changing layout to bhsd!") + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") # get seqlens N_CTX_Q = q.shape[2] @@ -66,6 +76,19 @@ def attention_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_ex o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16) + # go back to original layout + if layout == "bshd": + if DEBUG: + print("Changing back to bshd!") + if use_exp2: + o_exp2 = o_exp2.transpose(1, 2) + else: + o = o.transpose(1, 2) + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") + if use_exp2: return o_exp2, softmax_exp2_lse, exp2_scores, softmax_exp2, attention_shifted_scaled_scores, attention_scores else: diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 8183b0409..21e6cd23d 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -1,9 +1,9 @@ import torch import triton -from .utils import MetaData, get_shape_from_layout from .fwd_prefill import attention_prefill_forward_triton_impl from .bwd_prefill import attention_prefill_backward_triton_impl from .fwd_decode import attention_decode_forward_triton_impl +from .utils import MetaData, get_shape_from_layout DEBUG = False @@ -54,7 +54,7 @@ def fwd(q, if return_softmax: input_metadata.return_encoded_softmax = True - batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, input_metadata.layout) if causal: input_metadata.need_causal() @@ -117,7 +117,6 @@ def bwd( if dropout_p != 0.0: raise ValueError("dropout is not supported on AMD yet") - batch, max_seqlens_q, nheads_q, head_size = q.shape _, _, _, _, _, _ = attention_prefill_backward_triton_impl( dout, q, @@ -129,24 +128,19 @@ 
def bwd( dk, dv, softmax_scale, - head_size, alibi_slopes, causal, "bshd", + None, + None, + None, + None, False, True, True, ) - softmax_d = None # fill this in - if False: - print() - print("bwd output") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("softmax_d:", softmax_d) - print() + softmax_d = None return dq, dk, dv, softmax_d def varlen_fwd( @@ -172,6 +166,24 @@ def varlen_fwd( return_softmax, gen_): + if DEBUG: + print() + print("flash_attn_triton_amd.py::varlen_fwd") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("gen_:", gen_) + if dropout_p != 0.0: raise ValueError("dropout is not supported on AMD's Triton Backend yet") @@ -185,7 +197,7 @@ def varlen_fwd( input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata # get shapes - batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata) + batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, input_metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) if causal: input_metadata.need_causal() @@ -199,9 +211,9 @@ def varlen_fwd( # Check arguments input_metadata.check_args(q, k, v, o) - tri_out, softmax_lse, softmax_dmask = attention_prefill_forward_triton_impl(q, k, v, o, input_metadata) + o_triton, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted = attention_prefill_forward_triton_impl(q, k, v, o, input_metadata) - return tri_out, q , k , v, o, softmax_lse, softmax_dmask, None + return o_triton, q , k , v, o, softmax_lse, exp_scores, None def varlen_bwd( dout, @@ -229,7 +241,60 @@ def varlen_bwd( gen_, rng_state, ): - raise ValueError("varlen_bwd is not supported on AMD's Triton Backend yet") + if DEBUG: + print() + print("flash_attn_triton_amd.py::varlen_bwd") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD yet") + + _, _, _, _, _, _ = attention_prefill_backward_triton_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + False, + True, + True, + ) + + softmax_d = None + return dq, dk, dv, softmax_d def 
fwd_kvcache( q, diff --git a/flash_attn/flash_attn_triton_amd/interface_torch.py b/flash_attn/flash_attn_triton_amd/interface_torch.py index c33810206..b82194c99 100644 --- a/flash_attn/flash_attn_triton_amd/interface_torch.py +++ b/flash_attn/flash_attn_triton_amd/interface_torch.py @@ -39,10 +39,13 @@ def backward(ctx, do, *args): # expects bhsd None, None, ctx.sm_scale, - ctx.head_size, ctx.alibi_slopes, ctx.causal, ctx.layout, + None, + None, + None, + None, ctx.use_exp2, ctx.bwd_preprocessing_use_o, True, diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 7a1478d01..92fe802f0 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -9,7 +9,7 @@ from .bwd_ref import attention_backward_pytorch_ref_impl from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4 -DEBUG=False +DEBUG = True # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. @@ -227,7 +227,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): torch.manual_seed(20) - DEBUG_INPUT = False # if DEBUG is True it fails + DEBUG_INPUT = False # seqlens seqlen_q = N_CTX_Q @@ -341,16 +341,17 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali (1, 1, 256, 512, 16), (1, 1, 128, 128, 64), (2, 4, 1024, 1024, 64), + (4, 6, 108, 256, 224), (4, 8, 2048, 2048, 128), (4, 16, 4096, 4096, 64), (2, 4, 8192, 8192, 32), ]) @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('return_scores', [False]) -@pytest.mark.parametrize('check_softmax', [True, False]) +@pytest.mark.parametrize('layout', ["bshd", "bhsd"]) @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. 
Just use to figure out issues -def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, check_softmax, use_exp2, DEBUG_INPUT): +def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT): dtype = torch.float16 torch.manual_seed(0) @@ -358,21 +359,36 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor sm_scale = 1 else: sm_scale = D_HEAD ** -0.5 - layout = 'bhsd' alibi_slopes = None dropout_p = 0.0 if DEBUG_INPUT: - q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, 1, N_CTX_Q, 1).expand(Z, H, N_CTX_Q, D_HEAD).requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).requires_grad_() - v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).requires_grad_() - o = torch.zeros_like(q) + if layout == 'bhsd': + q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, 1, N_CTX_Q, 1).expand(Z, H, N_CTX_Q, D_HEAD).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() + o = torch.zeros_like(q).contiguous() + elif layout == "bshd": + q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, N_CTX_Q, 1, 1).expand(Z, N_CTX_Q, H, D_HEAD).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, N_CTX_K, 1, 1).expand(Z, N_CTX_K, H, D_HEAD).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, N_CTX_K, 1, 1).expand(Z, N_CTX_K, H, D_HEAD).contiguous().requires_grad_() + o = torch.zeros_like(q).contiguous() + else: + raise ValueError("Unknown layout") else: - # Generate random inputs - q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype) - k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype) - v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype) - o = torch.empty_like(q) + if layout == 'bhsd': + # Generate random inputs + q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + o = torch.empty_like(q) + elif layout == 'bshd': + q = torch.randn(Z, N_CTX_Q, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + k = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + v = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + o = torch.empty_like(q) + else: + raise ValueError("Unknown layout") # Set up metadata input_metadata = MetaData(sm_scale=sm_scale) @@ -384,7 +400,7 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor input_metadata.need_causal() # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that - if check_softmax or return_scores: + if return_scores: input_metadata.return_scores = True # call Triton's forward implementation directly @@ -399,7 +415,7 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor attention_shifted_scaled_scores_ref, attention_scores_ref, ) = attention_forward_pytorch_ref_impl( - q.clone(), k.clone(), v.clone(), sm_scale, causal, "bhsd", use_exp2 + q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2 ) if DEBUG: # ref output @@ -415,29 +431,36 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor print("ref_o:", o_ref, o_ref.shape) torch.testing.assert_close(o, o_ref, atol=ATOL, rtol=RTOL) - # compare with pytorch - out_pytorch, softmax_pytorch = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, - dropout_mask=None) if DEBUG: - print("o:", o, o.shape) - print("out_pytorch:", out_pytorch, out_pytorch.shape) - torch.testing.assert_close(o, out_pytorch, atol=ATOL, rtol=RTOL) + # compare softmax_lse with ref + print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) + print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) + torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) - if check_softmax: - if DEBUG: - # compare softmax_lse with ref - print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) - print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) - torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) + # use trick with lse to get the softmax. you need the scores but is it + softmax_triton = torch.exp(sm_scale * attention_scores_ref - softmax_lse_triton.unsqueeze(-1)) + if DEBUG: + print("softmax_triton:", softmax_triton, softmax_triton.shape) + print("softmax_ref:", softmax_ref, softmax_ref.shape) + torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL) - # use trick with lse to get the softmax. 
you need the scores but is it - softmax_triton = torch.exp(sm_scale * attention_scores_ref - softmax_lse_triton.unsqueeze(-1)) - if DEBUG: - print("softmax_triton:", softmax_triton, softmax_triton.shape) - print("softmax_ref:", softmax_ref, softmax_ref.shape) - torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL) + # compare with pytorch expect bhsd + if layout in ["bhsd", "bshd"]: + out_pytorch, softmax_pytorch = torch.ops.aten._scaled_dot_product_attention_math( + q.transpose(1, 2) if layout == "bshd" else q , + k.transpose(1, 2) if layout == "bshd" else k, + v.transpose(1, 2) if layout == "bshd" else v, + dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + out_pytorch = out_pytorch.transpose(1, 2) if layout == "bshd" else out_pytorch + + if DEBUG: + print("o:", o, o.shape) + print("out_pytorch:", out_pytorch, out_pytorch.shape) + torch.testing.assert_close(o, out_pytorch, atol=ATOL, rtol=RTOL) + # compare with pytorch output if DEBUG: print("softmax_triton:", softmax_triton, softmax_triton.shape) @@ -446,38 +469,44 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 1, 1, 1, 1), - (1, 1, 4, 4, 4), - (1, 1, 4, 4, 16), - (1, 1, 16, 16, 16), - (1, 1, 32, 32, 16), - (1, 1, 64, 64, 16), # pass # smallest head_size = 16 - (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 - (1, 1, 128, 128, 64), - (1, 1, 128, 256, 45), - (1, 1, 256, 256, 64), - (1, 1, 256, 512, 16), - (1, 1, 512, 512, 64), - (1, 1, 1024, 1024, 64), - # old tests that work - (4, 48, 1024, 1024, 73), - (4, 48, 1024, 1024, 64), - (4, 48, 2048, 2048, 64), - (1, 24, 4096, 4096, 64), - (1, 16, 1024, 1024, 64), - (1, 16, 1024, 1024, 128), + # (1, 1, 1, 1, 1), + # (1, 1, 4, 4, 4), + # (1, 1, 4, 4, 16), + # (1, 1, 16, 16, 16), + # (1, 1, 32, 32, 16), + # (1, 1, 64, 64, 16), # pass # smallest head_size = 16 + # (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 + # (1, 1, 128, 128, 64), + # (1, 1, 128, 256, 45), + # (1, 1, 256, 256, 64), + # (1, 1, 256, 512, 16), + # (1, 1, 512, 512, 64), + # (1, 1, 1024, 1024, 64), + # fa configs + # (2, 2, 128, 128, 65), + (2, 2, 128, 128, 224), + # (2, 2, 128, 128, 224), + # (2, 2, 108, 256, 224), + # (4, 6, 108, 256, 224), + # (1, 1, 256, 512, 16), + # # old tests that work + # (4, 48, 1024, 1024, 73), + # (4, 48, 1024, 1024, 64), + # (4, 48, 2048, 2048, 64), + # (1, 24, 4096, 4096, 64), + # (1, 16, 1024, 1024, 64), + # (1, 16, 1024, 1024, 128), # # old tests that were commented out # (1, 16, 8192, 8192, 63), # (1, 16, 1022, 1022, 64), - # bad fa configs - # (1, 1, 256, 512, 16), ]) @pytest.mark.parametrize('causal', [False]) -@pytest.mark.parametrize('use_exp2', [True, False]) -@pytest.mark.parametrize('bwd_preprocessing_use_o', [True, False]) +@pytest.mark.parametrize('use_exp2', [False]) +@pytest.mark.parametrize('bwd_preprocessing_use_o', [False]) +@pytest.mark.parametrize('layout', ["bhsd"]) @pytest.mark.parametrize('use_new', [True]) @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend -def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, use_new, DEBUG_INPUT): +def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT): dtype = torch.float16 torch.manual_seed(20) # seed from test_op_bwd @@ -485,21 +514,36 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, 
D_HEAD, causal, use_exp2, bwd_prepr sm_scale = 1 else: sm_scale = D_HEAD ** -0.5 - head_size = D_HEAD - layout = 'bhsd' alibi_slopes = None if DEBUG_INPUT: - q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, 1, N_CTX_Q, 1).expand(Z, H, N_CTX_Q, D_HEAD).contiguous().requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() - v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() - do = torch.ones_like(q).contiguous() + if layout == 'bhsd': + q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, 1, N_CTX_Q, 1).expand(Z, H, N_CTX_Q, D_HEAD).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, 1, N_CTX_K, 1).expand(Z, H, N_CTX_K, D_HEAD).contiguous().requires_grad_() + do = torch.ones_like(q).contiguous() + elif layout == "bshd": + q = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").view(1, N_CTX_Q, 1, 1).expand(Z, N_CTX_Q, H, D_HEAD).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, N_CTX_K, 1, 1).expand(Z, N_CTX_K, H, D_HEAD).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device="cuda").view(1, N_CTX_K, 1, 1).expand(Z, N_CTX_K, H, D_HEAD).contiguous().requires_grad_() + do = torch.ones_like(q).contiguous() + else: + raise ValueError("Unknown layout") else: - # Generate random inputs - q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - do = torch.randn_like(q) + if layout == 'bhsd': + # Generate random inputs + q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + do = torch.randn_like(q) + elif layout == 'bshd': + # Generate random inputs + q = torch.randn(Z, N_CTX_Q, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + k = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + v = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + do = torch.randn_like(q) + else: + raise ValueError("Unknown layout") # =============================================== Reference ============================================================== q_ref = q.clone() @@ -516,6 +560,7 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_prepr q_ref, k_ref, v_ref, sm_scale, causal, layout, use_exp2 ) if DEBUG: + print() print("attention_scores_ref:", attention_scores_ref, attention_scores_ref.shape) print("attention_shifted_scaled_scores_ref:", attention_shifted_scaled_scores_ref, attention_shifted_scaled_scores_ref.shape) print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape) @@ -546,8 +591,8 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_prepr ) # =============================================== Triton ============================================================== - o = o_ref.clone() - softmax_lse = softmax_lse_ref.clone() + 
o = o_ref.clone().contiguous() + softmax_lse = softmax_lse_ref.clone().contiguous() dq, dk, dv, delta, _, _ = attention_prefill_backward_triton_impl( do, q, @@ -559,10 +604,13 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_prepr dk, dv, sm_scale, - head_size, alibi_slopes, causal, layout, + None, + None, + None, + None, use_exp2, bwd_preprocessing_use_o=bwd_preprocessing_use_o, use_new=use_new, diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d589f4d2..3ac393f23 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -56,50 +56,38 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen return q, k, v, input_metadata -def get_shape_from_layout(q, k, metadata): - if metadata.layout == 'thd': - nheads_q, nheads_k = q.shape[1], k.shape[1] - head_size = q.shape[-1] - batch = metadata.num_contexts - elif metadata.layout == 'bhsd': - batch, nheads_q, _, head_size = q.shape - nheads_k = k.shape[1] - elif metadata.layout == 'bshd': - batch, _, nheads_q, head_size = q.shape - nheads_k = k.shape[2] +def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + if layout == 'bhsd': + batch_q, nheads_q, seqlen_q, head_size_q = q.shape + batch_k, nheads_k, seqlen_k, head_size_k = k.shape + elif layout == 'bshd': + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + elif layout == 'thd': + batch_q, seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] + batch_k, seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] else: assert False, "Got unsupported layout." - return batch, nheads_q, nheads_k, head_size - -def get_padded_headsize(size): - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (size - 1).bit_length() - # Smallest head_dim supported is 16. If smaller, the tile in the - # kernel is padded - there is no padding in memory for any dims. - padded_d_model = max(padded_d_model, 16) - return padded_d_model - - -def _strides(x: torch.Tensor, *stride_names: str): - if x is None: - return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + # assert + assert batch_q == batch_k + assert nheads_q == nheads_k # might not be true in mqa and gqa. Keep for now + assert head_size_q == head_size_k - assert x.ndim == len(stride_names) - return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k -# TODO: This can probably optimized to have fewer lines of code. 
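A minimal usage sketch of the two layout helpers reworked in this hunk, mirroring how the new backward path calls them; the tensor shapes, device, and module path are illustrative assumptions, not values taken from this patch:

    import torch
    from flash_attn.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout

    # "bshd" tensors: (batch, seqlen, nheads, head_size) -- hypothetical sizes
    q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 256, 4, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 256, 4, 64, device="cuda", dtype=torch.float16)
    o = torch.empty_like(q)

    # returns (batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k);
    # cu_seqlens/max_seqlen arguments are only needed for the varlen "thd" layout
    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, "bshd")

    # strides come back in (z, h, seq, d) order for every supported layout, so the
    # kernels can index "bshd", "bhsd", and "thd" inputs uniformly without transposing
    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, "bshd")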
-def get_strides_from_layout(q, k, v, o, metadata): - if metadata.layout == 'thd': +def get_strides_from_layout(q, k, v, o, layout): + if layout == 'thd': q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - elif metadata.layout == 'bhsd': + elif layout == 'bhsd': q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) - elif metadata.layout == 'bshd': + elif layout == 'bshd': q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) @@ -108,6 +96,24 @@ def get_strides_from_layout(q, k, v, o, metadata): assert False, 'Got unsupported layout.' return q_strides, k_strides, v_strides, o_strides +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + return padded_d_model + + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + + def get_input_shapes(): cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) @@ -201,7 +207,7 @@ def need_dropout(self, dropout_p, return_scores): def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() - batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) if self.varlen: assert q.dim() == 3 assert self.cu_seqlens_q is not None diff --git a/tests/test_flash_attn_triton.py b/tests/test_flash_attn_triton.py index 818ae1222..fe862aa56 100644 --- a/tests/test_flash_attn_triton.py +++ b/tests/test_flash_attn_triton.py @@ -19,17 +19,12 @@ from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb - +DEBUG = False # Test ROCM Triton Backend USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE" if USE_TRITON_ROCM: random.seed(42) -def skip_config(**kwargs): - if 'd' in kwargs: - return random.random() < 0.20 - return False - MAX_HEADDIM_SM8x = 192 @@ -225,7 +220,7 @@ def construct_local_mask( col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), col_idx < row_idx + sk - sq - window_size[0], ) -DEBUG = False + def attention_ref( q, @@ -264,15 +259,6 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ - if DEBUG: - print() - if upcast==False and reorder_ops==True: - print("attention_ref_py") - else: - print("attention_ref") - if DEBUG: - print("upcast:", upcast) - print("reorder_ops:", reorder_ops) if causal: window_size = (window_size[0], 0) dtype_og = q.dtype @@ -286,8 +272,6 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", 
q / math.sqrt(d), k) else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) - if DEBUG: - print("scores_ref:", scores) if softcap > 0: scores = scores / softcap scores = scores.tanh() @@ -307,11 +291,7 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - if DEBUG: - print("lse_ref:", torch.logsumexp(scores, dim=-1)) attention = torch.softmax(scores, dim=-1).to(v.dtype) - if DEBUG: - print("attention_ref:", attention) # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -592,16 +572,16 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [False]) -@pytest.mark.parametrize("alibi", [False, True]) -# @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) @@ -609,21 +589,17 @@ def get_dropout_fraction( # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize("seqlen", [512]) -@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -# @pytest.mark.parametrize("dropout_p", [0.0]) +# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +@pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): if USE_TRITON_ROCM: - test_backward = False - + test_backward = True if dropout_p != 0.0: pytest.skip("Dropout not supported in AMD's Triton Backend yet") if local == True: pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - if skip_config(seqlen=seqlen, d=d): - pytest.skip("Skipping configuration due to limited test time") - if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -754,38 +730,33 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) 
-@pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [True]) -@pytest.mark.parametrize("alibi", [False, True]) -# @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +@pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_qkvpacked( seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype ): if USE_TRITON_ROCM: - test_backward = False - + test_backward = True if dropout_p != 0.0: pytest.skip("Dropout not supported in AMD's Triton Backend yet") if local == True: pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - if skip_config(seqlen=seqlen, d=d): - pytest.skip("Skipping configuration due to limited test time") - if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -938,7 +909,6 @@ def test_flash_attn_varlen_qkvpacked( # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) -# @pytest.mark.parametrize("d", [16]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -964,8 +934,6 @@ def test_flash_attn_output( ): if USE_TRITON_ROCM: test_backward = True - DEBUG_INPUT= False - if dropout_p != 0.0: pytest.skip("Dropout not supported on AMD's Triton Backend yet") @@ -974,10 +942,6 @@ def test_flash_attn_output( if local == True: pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - # if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): - # pytest.skip("Skipping configuration due to limited test time") - if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -988,16 +952,12 @@ def test_flash_attn_output( device = "cuda" # set seed torch.random.manual_seed(0) - batch_size = 1 #4 - nheads = 1 # if softcap == 0.0 else 4 # softcap reference impl takes more memory + batch_size = 4 + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if DEBUG_INPUT: - q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, 
seqlen_q, 1, 1).expand(batch_size, seqlen_q, nheads, d).contiguous().requires_grad_() - else: - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) - + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) if softcap > 0: # Ensure the values of qk are at least within softcap range. q = q * softcap @@ -1006,16 +966,12 @@ def test_flash_attn_output( batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) else: - if DEBUG_INPUT: - k = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, seqlen_k, 1, 1).expand(batch_size,seqlen_k, nheads_k, d).contiguous().requires_grad_() - v = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, seqlen_k, 1, 1).expand(batch_size, seqlen_k, nheads_k, d).contiguous().requires_grad_() - else: - k = torch.randn( - batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True - ) - v = torch.randn( - batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True - ) + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) if alibi: alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) @@ -1151,10 +1107,7 @@ def test_flash_attn_output( print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") - if DEBUG_INPUT: - g = torch.ones_like(out).contiguous() - else: - g = torch.randn_like(out) + g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) test_backward = test_backward and ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)) if test_backward: @@ -1205,7 +1158,6 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - DEBUG=False if DEBUG: print("out:", out, out.shape) print("out_ref:", out_ref, out_ref.shape) @@ -1222,71 +1174,66 @@ def test_flash_attn_output( print("dv:", dv, dv.shape) print("dv_ref:", dv_ref, dv_ref.shape) print("dv_pt:", dv_pt, dv_pt.shape) - # fp16 default is ATOL, RTOL = 1e-5, 1e-3. 
See table https://pytorch.org/docs/stable/testing.html - ATOL, RTOL = 1e-4, 1e-3 - # torch.testing.assert_close(dv, dv_ref, atol=ATOL, rtol=RTOL) assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() if DEBUG: print("dk:", dk, dk.shape) print("dk_ref:", dk_ref, dk_ref.shape) print("dk_pt:", dk_pt, dk_pt.shape) - # torch.testing.assert_close(dk, dk_ref, atol=ATOL, rtol=RTOL) assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() if DEBUG: print("dq:", dq, dq.shape) print("dq_ref:", dq_ref, dq_ref.shape) print("dq_pt:", dq_pt, dq_pt.shape) - # torch.testing.assert_close(dq, dq_ref, atol=ATOL, rtol=RTOL) assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() -@pytest.mark.parametrize("kvpacked", [True, False]) +@pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize('mha_type', ["mqa"]) -@pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [True]) -@pytest.mark.parametrize("alibi", [False, True]) -# @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize('mha_type', ["mha"]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize('d', [32]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 147), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), + (4, 4) + # (1, 147), + # (113, 203), + # (128, 217), + # (113, 211), + # (108, 256), + # (256, 512), + # (512, 256), + # (1024, 1024), + # (1023, 1024), + # (1024, 1023), + # (2048, 2048), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) -@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize("softcap", [0.0, 50.0]) -# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +@pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("softcap", [0.0]) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): - if USE_TRITON_ROCM: - test_backward = False - + 
test_backward = True if dropout_p != 0.0: pytest.skip("Dropout not supported in AMD's Triton Backend yet") @@ -1295,9 +1242,6 @@ def test_flash_attn_varlen_output( if softcap != 0.0: pytest.skip("softcap not supported on AMD's Triton Backend yet") - - if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): - pytest.skip("Skipping configuration due to limited test time") if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1579,18 +1523,18 @@ def test_flash_attn_varlen_output( assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128]) -@pytest.mark.parametrize("swap_sq_sk", [False, True]) -# @pytest.mark.parametrize("swap_sq_sk", [True]) +# @pytest.mark.parametrize("swap_sq_sk", [False, True]) +@pytest.mark.parametrize("swap_sq_sk", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1609,14 +1553,9 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if USE_TRITON_ROCM: - test_backward = False - + test_backward = True if local == True: pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d): - pytest.skip("Skipping configuration due to limited test time") - if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1700,18 +1639,18 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) -@pytest.mark.parametrize("swap_sq_sk", [False, True]) -# @pytest.mark.parametrize("swap_sq_sk", [True]) +# @pytest.mark.parametrize("swap_sq_sk", [False, True]) 
+@pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -1734,8 +1673,7 @@ def test_flash_attn_varlen_causal(
     seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
 ):
     if USE_TRITON_ROCM:
-        test_backward = False
-
+        test_backward = True
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
@@ -1744,10 +1682,6 @@ def test_flash_attn_varlen_causal(
         if seqlen_q * seqlen_k >= 256 * 512:
             pytest.skip(f"{seqlen_q}, {seqlen_k} leads to out of memory on AMD")
-
-        if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1884,24 +1818,24 @@ def test_flash_attn_varlen_causal(
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
 
 
-@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("deterministic", [False, True])
-# @pytest.mark.parametrize("deterministic", [True])
-@pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [False])
-@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize("deterministic", [False, True])
+@pytest.mark.parametrize("deterministic", [False])
+# @pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("alibi", [False])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64])
-@pytest.mark.parametrize("swap_sq_sk", [False, True])
-# @pytest.mark.parametrize("swap_sq_sk", [False])
+# @pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -1920,16 +1854,10 @@ def test_flash_attn_varlen_causal(
 def test_flash_attn_splitkv(
     seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
 ):
-    if USE_TRITON_ROCM:
-        test_backward = False
-
+    test_backward = True
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
-
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
@@ -2102,10 +2030,6 @@ def test_flash_attn_kvcache(
         if has_leftpad == True:
             pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet")
-
-        if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
@@ -2353,8 +2277,8 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k,
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
@@ -2382,14 +2306,9 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k,
 # @pytest.mark.parametrize("dropout_p", [0.0])
 def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
     if USE_TRITON_ROCM:
-        test_backward = False
-
+        test_backward = True
         if dropout_p != 0.0:
             pytest.skip("Dropout not supported in AMD's Triton Backend yet")
-
-        if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -2442,13 +2361,6 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
     """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
     in the case where seqlen % 128 != 0.
     """
-    if USE_TRITON_ROCM:
-        if True:
-            pytest.skip("Backward Attention not supported on AMD's Triton Backend yet")
-
-        if skip_config(seqlen=seqlen, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -2505,13 +2417,6 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
     """We previously had a bug where we were using the wrong strides of dout,
     which shows up when dout is not contiguous.
     """
-    if USE_TRITON_ROCM:
-        if True:
-            pytest.skip("Backward Attention not supported on AMD's Triton Backend yet")
-
-        if skip_config(seqlen=seqlen, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -2564,13 +2469,6 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
     """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
     in the case where seqlen % 128 != 0 or varlen.
     """
-    if USE_TRITON_ROCM:
-        if True:
-            pytest.skip("Backward Attention not supported on AMD's Triton Backend yet")
-
-        if skip_config(d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -2595,20 +2493,20 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
     assert not v.grad.isnan().any()
 
 
-@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
-@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64])
-@pytest.mark.parametrize("swap_sq_sk", [False, True])
-# @pytest.mark.parametrize("swap_sq_sk", [False])
+# @pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -2627,14 +2525,9 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
     if USE_TRITON_ROCM:
-        test_backward = False
-
+        test_backward = True
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
-
-        if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -2664,20 +2557,20 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     assert torch.equal(dq, dq0)
 
 
-@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
-@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64])
-@pytest.mark.parametrize("swap_sq_sk", [False, True])
-# @pytest.mark.parametrize("swap_sq_sk", [True])
+# @pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -2696,14 +2589,9 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
 def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
     if USE_TRITON_ROCM:
-        test_backward = False
-
+        test_backward = True
         if local == True:
             pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
-
-        if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-            pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30