fix varlen int64 bug
micmelesse committed Oct 19, 2024
1 parent 2db6502 commit d192946
Showing 6 changed files with 157 additions and 79 deletions.
149 changes: 108 additions & 41 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -92,7 +92,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, sm_scale: tl.constexpr, USE_EXP2: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
RETURN_SCORES: tl.constexpr):
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
@@ -128,7 +128,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * sm_scale
qk_scaled = qk * SM_SCALE
if RETURN_SCORES:
score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
tl.store(score_ptrs, qk_scaled, mask=score_mask)
@@ -235,16 +235,15 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, return_scores: tl.constexpr, USE_ALIBI: tl.constexpr,
USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr):
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
@@ -254,6 +253,8 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
# print("cu_seqlens_q_start:", cu_seqlens_q_start)

seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
Expand Down Expand Up @@ -316,6 +317,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,
off_h_k = off_h_q

n_extra_tokens = 0
# print("n_extra_tokens:", n_extra_tokens)
# print("seqlen_k:", seqlen_k)
# print("BLOCK_N:", BLOCK_N)
# return
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
@@ -400,7 +405,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, sm_scale, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
block_min = block_max
block_max = n_blocks * BLOCK_N

Expand All @@ -426,9 +431,11 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL, sm_scale, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
# epilogue
acc = acc / l_i[:, None]
# This helps the compiler do Newton-Raphson on l_i rather than on acc, which is much larger.
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
Expand Down Expand Up @@ -462,8 +469,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,

if IS_CAUSAL:
# zero out nans caused by -infs when doing causal
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
lse_mask = mask_m_offsets < causal_start_idx
lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)

# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
Expand All @@ -487,30 +493,64 @@ def attn_fwd(Q, K, V, bias, sm_scale, LSE, Out, stride_qz, stride_qh, stride_qm,
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)


def attention_prefill_forward_triton_impl(q, k, v, o, metadata):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (metadata.bias is not None):
assert (metadata.bias.numel() < 2**31)
def attention_prefill_forward_triton_impl_explicit(
q,
k,
v,
o,
sm_scale,
alibi_slopes,
causal,
bias,
dropout_p,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
return_scores,
use_exp2):

if DEBUG:
print()
print("attention_prefill_forward_triton_impl_explicit")
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("sm_scale:", sm_scale)
print("alibi_slopes:", alibi_slopes)
print("causal:", causal)
print("bias:", bias)
print("dropout_p:", dropout_p)
print("layout:", layout)
print("cu_seqlens_q:", cu_seqlens_q)
print("cu_seqlens_k:", cu_seqlens_k)
print("max_seqlens_q:", max_seqlens_q)
print("max_seqlens_k:", max_seqlens_k)
print("return_scores:", return_scores)
print("use_exp2:", use_exp2)

if o is None:
o = torch.empty_like(q, dtype=v.dtype)
metadata.check_args(q, k, v, o)

batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata.layout)
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (bias is not None):
assert (bias.numel() < 2**31)

batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)

# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
# Smallest head_dim supported is 16. If smaller, the tile in the
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)

grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch)

if metadata.return_scores:
scores = torch.zeros((batch, nheads_q, metadata.max_seqlens_q, metadata.max_seqlens_k), device=q.device,
if return_scores:
scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_scaled_shifted = torch.zeros((batch, nheads_q, metadata.max_seqlens_q, metadata.max_seqlens_k), device=q.device,
scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
else:
@@ -522,39 +562,66 @@ def attention_prefill_forward_triton_impl(q, k, v, o, metadata):
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if metadata.return_scores:
exp_scores = torch.zeros((batch, nheads_q, metadata.max_seqlens_q, metadata.max_seqlens_k), device=q.device,
if return_scores:
exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q,max_seqlens_k), device=q.device,
dtype=torch.float32)
else:
exp_scores = None

# stores LSE, the log of the normalization constant / sum of exponential scores (unnormalized probabilities)
softmax_lse = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32)
softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)

# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42

if metadata.bias is not None:
bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2),
metadata.bias.stride(3))
if bias is not None:
bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2),
bias.stride(3))
else:
bias_strides = (0, 0, 0, 0)

if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1))
if alibi_slopes is not None:
alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)

attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k,
dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True,
USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p
> 0.0, return_scores=metadata.return_scores,
USE_EXP2=metadata.use_exp2, RETURN_SCORES=metadata.return_scores)

is_varlen = layout == "thd"
print("cu_seqlens_q:", cu_seqlens_q)
print("cu_seqlens_k:", cu_seqlens_k)
attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)

return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted


def attention_prefill_forward_triton_impl(q, k, v, o, metadata):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
metadata.check_args(q, k, v, o)

return attention_prefill_forward_triton_impl_explicit(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
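
For context, the sketch below is not part of this commit: it shows one way the new explicit entry point might be called directly with a standard "bhsd" layout. Only the function name, argument order, and return values come from the diff above; the tensor shapes, scale value, and flag choices are illustrative assumptions.

import torch

# Illustrative sizes only; any shapes supported by the kernel would do.
batch, nheads, seqlen, head_dim = 2, 4, 128, 64
q = torch.randn(batch, nheads, seqlen, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

out, softmax_lse, exp_scores, *_ = attention_prefill_forward_triton_impl_explicit(
    q, k, v, o,
    1.0 / head_dim ** 0.5,  # sm_scale
    None,                   # alibi_slopes
    True,                   # causal
    None,                   # bias
    0.0,                    # dropout_p
    "bhsd",                 # layout
    None, None,             # cu_seqlens_q / cu_seqlens_k (unused for non-varlen layouts)
    seqlen, seqlen,         # max_seqlens_q / max_seqlens_k
    False,                  # return_scores
    False)                  # use_exp2

The metadata-based wrapper attention_prefill_forward_triton_impl remains a thin shim that unpacks the same arguments from a MetaData object.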
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -1,7 +1,7 @@
import torch
import math

DEBUG = True
DEBUG = False

def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
if DEBUG:
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -5,7 +5,7 @@
from .fwd_decode import attention_decode_forward_triton_impl
from .utils import MetaData, get_shape_from_layout

DEBUG = True
DEBUG = False

def fwd(q,
k,
77 changes: 44 additions & 33 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -9,7 +9,7 @@
from .bwd_ref import attention_backward_pytorch_ref_impl
from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4

DEBUG = True
DEBUG = False

# default fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe too loose.
@@ -147,10 +147,21 @@ def test_op_fwd_prefill_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, d
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64),
(4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64),
(4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128),
(4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 8192, 64),
# (4, 48, 256, 64),
# (4, 48, 512, 64),
# (4, 48, 1024, 64),
# (8, 48, 4096, 64),
# (4, 48, 8192, 64),
# (4, 48, 128, 128),
# (4, 48, 4096, 128),
# (4, 48, 16384, 128),
# (4, 16, 1024, 128),
# (4, 16, 8192, 128),
# (32, 48, 8192, 128)
]
)
@pytest.mark.parametrize('causal', [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):

@@ -331,38 +342,38 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali


@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
# (1, 1, 1, 1, 1),
# (1, 1, 2, 4, 16),
# (1, 1, 4, 2, 16),
# (1, 1, 4, 4, 16),
# (1, 2, 4, 4, 16),
# (2, 1, 4, 4, 16),
# (2, 2, 4, 4, 16),
(1, 1, 1, 1, 1),
(1, 1, 2, 4, 16),
(1, 1, 4, 2, 16),
(1, 1, 4, 4, 16),
(1, 2, 4, 4, 16),
(2, 1, 4, 4, 16),
(2, 2, 4, 4, 16),
(1, 1, 128, 64, 16),
# (2, 2, 2, 128, 1),
# (2, 3, 2, 128, 16),
# (3, 2, 256, 512, 16),
# (3, 3, 128, 128, 64),
# (2, 4, 1024, 1024, 64),
# (4, 6, 108, 256, 224),
# (4, 8, 2048, 2048, 128),
# (4, 16, 4096, 4096, 64),
# (2, 4, 8192, 8192, 32),
(2, 2, 2, 128, 1),
(2, 3, 2, 128, 16),
(3, 2, 256, 512, 16),
(3, 3, 128, 128, 64),
(2, 4, 1024, 1024, 64),
(4, 6, 108, 256, 224),
(4, 8, 2048, 2048, 128),
(4, 16, 4096, 4096, 64),
(2, 4, 8192, 8192, 32),
# # fa configs
# (4, 6, 113, 203, 256),
# (4, 6, 128, 217, 256),
# (4, 6, 113, 211, 128),
# (4, 6, 108, 256, 128),
# (4, 6, 256, 512, 64),
# (4, 6, 512, 256, 64),
# (4, 6, 1024, 1024, 32),
# (4, 6, 1023, 1024, 32),
# (4, 6, 1024, 1023, 32),
# (4, 6, 2048, 2048, 32),
(4, 6, 113, 203, 256),
(4, 6, 128, 217, 256),
(4, 6, 113, 211, 128),
(4, 6, 108, 256, 128),
(4, 6, 256, 512, 64),
(4, 6, 512, 256, 64),
(4, 6, 1024, 1024, 32),
(4, 6, 1023, 1024, 32),
(4, 6, 1024, 1023, 32),
(4, 6, 2048, 2048, 32),
])
@pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('return_scores', [False])
@pytest.mark.parametrize('layout', ["bhsd"])
@pytest.mark.parametrize('layout', ["thd"])
@pytest.mark.parametrize('use_exp2', [False]) # works when use_exp2 is false
@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues
def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT):
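
Since test_op_fwd_prefill_impl is now parametrized with the "thd" (varlen) layout, a hypothetical helper for assembling varlen inputs is sketched below. It is not part of this commit; the int32 dtype for the cumulative sequence lengths is only the usual convention for varlen interfaces, and this sketch does not assert the root cause of the int64 bug the commit fixes.

import torch

def build_varlen_inputs(seqlens_q, seqlens_k, nheads, head_dim,
                        device="cuda", dtype=torch.float16):
    # Pack all sequences along the first ("t") dimension: (total_tokens, nheads, head_dim).
    q = torch.randn(sum(seqlens_q), nheads, head_dim, device=device, dtype=dtype)
    k = torch.randn(sum(seqlens_k), nheads, head_dim, device=device, dtype=dtype)
    v = torch.randn(sum(seqlens_k), nheads, head_dim, device=device, dtype=dtype)

    def cu_seqlens(seqlens):
        # Prefix-sum offsets [0, s0, s0+s1, ...]; stored as int32 by convention.
        cu = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device)
        cu[1:] = torch.cumsum(torch.tensor(seqlens, device=device), dim=0)
        return cu

    return (q, k, v,
            cu_seqlens(seqlens_q), cu_seqlens(seqlens_k),
            max(seqlens_q), max(seqlens_k))

The resulting tensors and offsets could then be passed to the explicit forward entry point with layout="thd".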