Skip to content

Commit

Permalink
fix varlen bug
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Dec 3, 2024
1 parent b577610 commit e228683
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 37 deletions.
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ def _bwd_kernel(
delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam

if DROPOUT:
batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth + q_start * stride_dropoutm
dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth + q_start * stride_dropoutm
batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm
dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm
else:
batch_philox_offset = 0
dropout_offset = 0
Expand Down
24 changes: 12 additions & 12 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
Expand Down Expand Up @@ -167,11 +167,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if RETURN_SCORES:
sd_mask_ptrs += BLOCK_N
sd_mask_ptrs += BLOCK_N * stride_sn

if ENABLE_DROPOUT:
dropout_mask_ptrs += BLOCK_N
philox_ptrs += BLOCK_N
dropout_mask_ptrs += BLOCK_N * stride_sn
philox_ptrs += BLOCK_N * stride_sn
return acc, l_i, m_i


Expand Down Expand Up @@ -364,15 +364,15 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
alibi_slope = None

if RETURN_SCORES:
sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
else:
sd_mask_ptrs = None

if ENABLE_DROPOUT:
dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
else:
dropout_mask_ptrs = None
Expand Down Expand Up @@ -407,7 +407,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
Expand All @@ -432,11 +432,11 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
if USE_BIAS:
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
if RETURN_SCORES:
sd_mask_ptrs += n_full_blocks * BLOCK_N
sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn
if ENABLE_DROPOUT:
dropout_mask_ptrs += n_full_blocks * BLOCK_N
philox_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn
philox_ptrs += n_full_blocks * BLOCK_N * stride_sn
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
Expand Down
46 changes: 23 additions & 23 deletions tests/test_flash_attn_triton_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,8 +1195,8 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('kvpacked', [False])
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize('mha_type', ["mha"])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
Expand All @@ -1205,29 +1205,29 @@ def test_flash_attn_output(
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize('d', [32])
# @pytest.mark.parametrize('d', [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(32, 32),
# (1, 147),
# (113, 203),
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
# (512, 256),
# (1024, 1024),
# (1023, 1024),
# (1024, 1023),
# (2048, 2048),
# (32, 32),
(1, 147),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.17])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
Expand Down Expand Up @@ -1255,20 +1255,20 @@ def test_flash_attn_varlen_output(
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap

if kvpacked:
kv = torch.ones(
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.ones(
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.ones(
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)

Expand Down Expand Up @@ -1458,7 +1458,7 @@ def test_flash_attn_varlen_output(
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.ones_like(out)
g = torch.randn_like(out)
if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
if kvpacked:
(
Expand Down

0 comments on commit e228683

Please sign in to comment.