diff --git a/flash_attn/flash_attn_triton_kernel_decode_amd.py b/flash_attn/flash_attn_triton_kernel_decode_amd.py
index a01b33ad0..7cb7c8000 100644
--- a/flash_attn/flash_attn_triton_kernel_decode_amd.py
+++ b/flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -113,24 +113,16 @@ def _fwd_kernel_splitK(
     )
 
     # Quantization
-    # QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
-    # PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_QUANT_GROUPS
-    # D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_QUANT_GROUPS
-    tl.static_assert(N_QUANT_GROUPS == 1, "N_QUANT_GROUPS != 1")
-    tl.static_assert(PACKED_PER_VAL == 1, "PACKED_PER_VAL != 1")
+    # TODO: enable quantization
+    tl.static_assert(N_QUANT_GROUPS == 1, "N_QUANT_GROUPS != 1. Quantization is not supported yet.")
+    tl.static_assert(PACKED_PER_VAL == 1, "PACKED_PER_VAL != 1. Quantization is not supported yet.")
     QUANTIZED: tl.constexpr = 0
 
     # Padding
-    # print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL)
-    # print("BLOCK_DMODEL:", BLOCK_DMODEL)
     PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
-    # print("PADDED_HEAD:", PADDED_HEAD)
     if PADDED_HEAD:
         d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
-        # print("d_mask:", d_mask)
-
-
     start_m = tl.program_id(0)
     off_zhg = tl.program_id(1)
@@ -151,7 +143,6 @@ def _fwd_kernel_splitK(
         alibi_slope = tl.load(Alibi_slopes + a_offset)
     else:
         alibi_slope = None
-        # print("alibi_slope:", alibi_slope)
 
     lo = splitk_idx * BLOCK_N_PER_SPLIT
     if USE_CACHE_SEQLENs:
@@ -177,7 +168,6 @@ def _fwd_kernel_splitK(
     v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg
 
     # Copy new Keys and Values into Cache
-    # print("NEW_KV", NEW_KV)
     if NEW_KV:
         knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
 
@@ -298,13 +288,9 @@ def _fwd_kernel_splitK(
     q = (q * qk_scale).to(q.dtype)
     if PADDED_HEAD:
         q = tl.where(d_mask[None, :], q, 0.0)
-    # print("q:", q)
 
-    # print("BLOCK_N:", BLOCK_N)
     # loop over k, v and update accumulator
     for start_n in range(lo, hi, BLOCK_N):
-        # print("start_n:", start_n)
-
         k, v = load_dequantize_k_v_group(
             K_block_ptr,
             V_block_ptr,
@@ -324,7 +310,6 @@ def _fwd_kernel_splitK(
         # -- compute qk ---
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)  # noqa: F821
-        # print("qk before:", qk)
 
         if USE_ALIBI:
             row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -333,19 +318,15 @@ def _fwd_kernel_splitK(
             # Compute relative positions
             relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
             relative_pos = tl.abs(relative_pos)
-            # print("relative_pos:", relative_pos)
 
             # Compute ALiBi bias
             alibi_bias = -1 * alibi_slope * relative_pos
-            # print("alibi_bias:", alibi_bias)
             qk += (alibi_bias * 1.44269504)
 
         # Apply causal mask if IS_CAUSAL is True
         if IS_CAUSAL:
             row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-            # print("row_idx:", row_idx)
             col_idx = start_n + tl.arange(0, BLOCK_N)
-            # print("col_idx:", col_idx)
 
             # create a N_CTX_Q x kv_len causal mask
             col_offset = N_CTX_Q - kv_len
@@ -353,25 +334,18 @@ def _fwd_kernel_splitK(
 
             # Apply the mask
             qk = tl.where(causal_mask, qk, float("-inf"))
-            # print("qk after causal:", qk)
 
         # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
         if BOUNDS_CHECKS_N:
             qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
-            # print("qk after BOUNDS_CHECKS_N:", qk)
 
         # -- compute scaling constant ---
-        # print("m_i:", m_i)
         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
-        # print("m_i_new:", m_i_new)
         if IS_CAUSAL:
             alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
         else:
             alpha = tl.math.exp2(m_i - m_i_new)
-
-        # print("alpha:", alpha)
-        # print("before qk - m_i_new:", qk)
         # cause of nan because subtracting infs
         if IS_CAUSAL:
             qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
@@ -379,7 +353,6 @@ def _fwd_kernel_splitK(
             qk = qk - m_i_new[:, None]
 
         p = tl.math.exp2(qk)
-        # print("p:", p)
 
         # -- update m_i and l_i --
         l_i = l_i * alpha + tl.sum(p, 1)
@@ -389,7 +362,6 @@ def _fwd_kernel_splitK(
         # -- scale and update acc --
         acc *= alpha[:, None]
         acc += tl.dot(p.to(v.dtype), v)
-        # print("acc:", acc)
 
         # update pointers
         K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
@@ -559,19 +531,13 @@ def _splitK_reduce(
     l_sum = tl.load(Metadata_ptr + stride_m2)
     acc = tl.load(o_ptr)
 
-    # print("l_m:", l_m)
-    # print("l_sum:", l_sum)
-    # print("acc:", acc)
-
     g_m = tl.max(l_m, axis=0)
-    # print("g_m:", g_m)
 
     if IS_CAUSAL:
         l_m_offset = l_m - g_m
         alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
     else:
         alpha = tl.math.exp2(l_m - g_m)
-    # print("alpha:", alpha)
 
     # read sum
     l_sum *= alpha
@@ -760,7 +726,6 @@ def forward(cls, q, k, v, input_metadata):
             dim_k = (dim_k - cls.NUM_QUANT_GROUPS) * 8
         assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
 
-        # print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, heads_per_group_q = {heads_per_group_q}, heads_per_group_k = {heads_per_group_k}, dim_q = {dim_q}, dim_k = {dim_k}")
         BLOCK_M = cls.BLOCK_M
         BLOCK_N = cls.BLOCK_N
diff --git a/flash_attn/flash_attn_triton_kernel_prefill_amd.py b/flash_attn/flash_attn_triton_kernel_prefill_amd.py
index 25e9859ac..3a061d07e 100644
--- a/flash_attn/flash_attn_triton_kernel_prefill_amd.py
+++ b/flash_attn/flash_attn_triton_kernel_prefill_amd.py
@@ -187,15 +187,6 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
     return tensor
 
 
-@triton.jit
-def print_gpu(prefix, val=None):
-    if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)):
-        if val is not None:
-            tl.device_print(prefix, val)
-        else:
-            tl.device_print(prefix)
-
-
 @triton.jit
 def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
     # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
@@ -348,12 +339,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                       num_warps=4),
         triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=8),
-        # TODO: This config fails with head_size not pow2 with data mismatches. Check why.
-        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
-        # triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
-        #               num_warps=4),
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
+        # TODO: This config fails with data mismatches when head_size is not a power of 2. Figure out why.
+        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
+        # triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
     ],
     key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
     use_cuda_graph=True,
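
Reviewer note (not part of the patch): the qk_scale/exp2 arithmetic the cleaned-up decode kernel relies on is easiest to sanity-check outside Triton. Below is a minimal NumPy sketch of the same streaming (online) softmax update per BLOCK_N tile: running row max m_i, denominator l_i, accumulator acc, and the alpha rescaling, with the log2(e) ~= 1.44269504 factor folded into the scale so exp2 can replace exp, as the kernel does. The names (streaming_attention, block_n) are illustrative only, not anything defined in these files.

# Standalone sketch of the per-tile update in _fwd_kernel_splitK, on plain NumPy arrays.
import numpy as np

def streaming_attention(q, k, v, block_n=16):
    d = q.shape[-1]
    qk_scale = (1.0 / np.sqrt(d)) * 1.44269504   # sm_scale * log2(e)
    m_i = np.full(q.shape[0], -np.inf)           # running row max (log2 domain)
    l_i = np.zeros(q.shape[0])                   # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[-1]))    # running weighted sum of V rows

    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q * qk_scale) @ kb.T               # scores, already scaled by log2(e)
        m_new = np.maximum(m_i, qk.max(axis=1))
        alpha = np.exp2(m_i - m_new)             # rescale previously accumulated state
        p = np.exp2(qk - m_new[:, None])
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_new
    return acc / l_i[:, None]

# Quick check against the direct softmax(QK^T / sqrt(d)) V computation.
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((4, 32)), rng.standard_normal((40, 32)), rng.standard_normal((40, 32))
scores = (q @ k.T) / np.sqrt(32)
ref = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(streaming_attention(q, k, v), ref)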