Commit

varlen ref working
micmelesse committed Nov 20, 2024
1 parent 7a4eafe commit a247186
Showing 3 changed files with 64 additions and 52 deletions.
56 changes: 32 additions & 24 deletions flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -128,6 +128,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox
print("o:", o, o.shape)

return o, softmax_lse, sd_mask, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores

def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
"""Compute reference output and softmax_lse using PyTorch's built-in function"""

@@ -193,7 +194,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout
if layout == "bshd":
o = o.transpose(1, 2)

return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
return o, softmax_lse, exp_scores, None, None, None, None


def attention_varlen_forward_pytorch_ref_impl(
@@ -222,9 +223,11 @@ def attention_varlen_forward_pytorch_ref_impl(

# Pre-allocate outputs
total_L_q = q.shape[0]
total_L_k = k.shape[0]

o = torch.empty((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device)
softmax_lse = torch.empty((total_L_q, nheads_q), dtype=torch.float32, device=q.device)
sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device)

# Compute group_size for MQA/GQA handling
group_size = nheads_q // nheads_k
@@ -275,12 +278,12 @@ def attention_varlen_forward_pytorch_ref_impl(
(
o_i,
softmax_lse_i,
exp_scores_i,
softmax_i,
attention_shifted_scaled_scores_i,
attention_scaled_scores_i,
attention_scores_i,
) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, use_exp2)
sd_mask_i,
_,
_,
_,
_,
) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2)

# Reshape outputs back to original dimensions
if group_size != 1:
@@ -298,15 +301,20 @@ def attention_varlen_forward_pytorch_ref_impl(
# Convert back to 'thd' layout and float16
o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, nheads_q, head_dim]
softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q]
sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i]

# print("sd_mask_i: ", sd_mask_i)

# Place outputs in pre-allocated tensors
o[start_q:end_q, :, :] = o_i
softmax_lse[start_q:end_q, :] = softmax_lse_i

sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i

return (
o,
softmax_lse,
None,
sd_mask,
None,
None,
None,
@@ -354,11 +362,11 @@ def attention_forward_pytorch_ref_impl(
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
sd_mask_ref,
_,
_,
_,
_,
) = attention_varlen_forward_pytorch_ref_impl(
q.clone(),
k.clone(),
@@ -379,11 +387,11 @@ def attention_forward_pytorch_ref_impl(
(
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
sd_mask_ref,
_,
_,
_,
_,
) = attention_vanilla_forward_pytorch_ref_impl(q.clone(),
k.clone(),
v.clone(),
@@ -400,16 +408,16 @@ def attention_forward_pytorch_ref_impl(
print("attention_forward_pytorch_ref_impl outputs")
print("o_ref:", o_ref, o_ref.shape)
print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None)
print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape if sd_mask_ref is not None else None)

return (
o_ref,
softmax_lse_ref,
exp_scores_ref,
softmax_ref,
attention_shifted_scaled_scores_ref,
attention_scaled_scores_ref,
attention_scores_ref,
sd_mask_ref,
None,
None,
None,
None,
)


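For orientation, the varlen reference path above boils down to: slice the packed "thd" tensors per sequence using cu_seqlens, run ordinary attention on each slice, then scatter the per-sequence outputs and score/dropout masks into pre-allocated buffers (sd_mask being padded to (batch, nheads_q, max_seqlen_q, max_seqlen_k)). The sketch below illustrates only that bookkeeping; it is not the repository's attention_varlen_forward_pytorch_ref_impl and omits causal masking, GQA grouping, and dropout.

```python
import torch


def varlen_attention_ref_sketch(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                max_seqlen_q, max_seqlen_k, sm_scale):
    """Per-sequence reference attention over packed "thd" inputs.

    q: (total_q, nheads, d); k, v: (total_k, nheads, d).
    cu_seqlens_*: (batch + 1,) prefix sums of the per-sequence lengths.
    Returns the packed output plus a padded (batch, nheads, max_seqlen_q,
    max_seqlen_k) score buffer, mirroring the sd_mask scatter above.
    """
    batch = cu_seqlens_q.numel() - 1
    nheads = q.shape[1]

    o = torch.empty_like(q)
    scores_padded = torch.zeros(batch, nheads, max_seqlen_q, max_seqlen_k,
                                dtype=torch.float32, device=q.device)

    for i in range(batch):
        sq, eq = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        sk, ek = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()

        # Slice sequence i and move heads to the front: (nheads, L, d).
        q_i = q[sq:eq].transpose(0, 1).float()
        k_i = k[sk:ek].transpose(0, 1).float()
        v_i = v[sk:ek].transpose(0, 1).float()

        scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * sm_scale
        p = torch.softmax(scores, dim=-1)           # (nheads, L_q_i, L_k_i)
        o_i = torch.matmul(p, v_i)                  # (nheads, L_q_i, d)

        # Back to "thd" layout and into the pre-allocated outputs.
        o[sq:eq] = o_i.transpose(0, 1).to(q.dtype)
        scores_padded[i, :, : eq - sq, : ek - sk] = p

    return o, scores_padded


if __name__ == "__main__":
    torch.manual_seed(0)
    cu_q = torch.tensor([0, 3, 7])                  # two sequences: lengths 3 and 4
    cu_k = torch.tensor([0, 4, 9])                  # lengths 4 and 5
    q = torch.randn(7, 2, 8)
    k = torch.randn(9, 2, 8)
    v = torch.randn(9, 2, 8)
    out, mask = varlen_attention_ref_sketch(q, k, v, cu_q, cu_k, 4, 5, 8 ** -0.5)
    print(out.shape, mask.shape)                    # (7, 2, 8) and (2, 2, 4, 5)
```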
38 changes: 25 additions & 13 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -286,9 +286,6 @@ def varlen_fwd(
print("window_size_right:", window_size_right)
print("gen_:", gen_)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")

if o is None:
o = torch.empty_like(q)

@@ -309,6 +306,9 @@

if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensor uses the underlying data and does not cast
else:
rng_state = None

# Check arguments
metadata.check_args(q, k, v, o)
@@ -320,7 +320,7 @@
print("Using reference implementation")
(output,
softmax_lse,
exp_scores,
sd_mask,
_,
_,
_,
@@ -335,14 +335,17 @@
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.dropout_p,
metadata.philox_seed,
metadata.philox_offset,
metadata.use_exp2)
o.copy_(output)
else:
if DEBUG:
print("Using Triton implementation")
(_,
softmax_lse,
exp_scores,
sd_mask,
_,
_,
_,
@@ -356,23 +359,25 @@
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.bias,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.max_seqlens_k,
metadata.dropout_p,
metadata.philox_seed,
metadata.philox_offset,
metadata.return_scores,
metadata.use_exp2)
if DEBUG:
print("varlen_fwd outputs")
print("o:", o, o.shape)
print("softmax_lse:", softmax_lse, softmax_lse.shape)
print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None )
print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None )


return o, softmax_lse, exp_scores, None
return o, softmax_lse, sd_mask, rng_state

def varlen_bwd(
dout,
@@ -426,9 +431,10 @@ def varlen_bwd(
print("gen_:", gen_)
print("rng_state:", rng_state)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")

if dropout_p > 0.0:
philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
else:
philox_seed, philox_offset = None, None
if USE_REF:
if DEBUG:
print("Using reference implementation")
@@ -446,6 +452,9 @@
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
philox_seed,
philox_offset,
False,
)
dq.copy_(dq_ref)
@@ -473,6 +482,9 @@
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
philox_seed,
philox_offset,
False,
)
delta = delta_triton
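The dropout plumbing added above is essentially a round trip of the Philox seed/offset: the forward packs them into rng_state, and the backward reads them out so it can replay the same dropout mask. A minimal sketch with placeholder values (1234 and 5678 stand in for metadata.philox_seed / metadata.philox_offset, which the real code gets from the dropout setup):

```python
import torch

# Placeholder RNG parameters; the backend derives the real ones when dropout is enabled.
philox_seed, philox_offset = 1234, 5678

# Forward: pack the RNG state for the caller. torch.as_tensor infers an
# integer dtype from the Python ints rather than casting to float.
rng_state = torch.as_tensor([philox_seed, philox_offset])

# Backward: unpack and hand the same seed/offset back to the kernel or
# reference implementation so the dropout pattern is reproduced exactly.
seed, offset = rng_state[0].item(), rng_state[1].item()
assert (seed, offset) == (philox_seed, philox_offset)
```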
22 changes: 7 additions & 15 deletions tests/test_flash_attn_triton_amd.py
@@ -593,9 +593,6 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if USE_TRITON_ROCM:
if dropout_p != 0.0:
pytest.skip("Dropout not supported in AMD's Triton Backend yet")

if local == True:
pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")

@@ -753,9 +750,6 @@ def test_flash_attn_varlen_qkvpacked(
seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
if USE_TRITON_ROCM:
if dropout_p != 0.0:
pytest.skip("Dropout not supported in AMD's Triton Backend yet")

if local == True:
pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
@@ -891,8 +885,8 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("deterministic", [False])
@@ -951,8 +945,8 @@ def test_flash_attn_output(
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 1
nheads = 1 if softcap == 0.0 else 4 # softcap reference impl takes more memory
batch_size = 4
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
@@ -1222,6 +1216,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
# (4, 4),
(1, 147),
(113, 203),
(128, 217),
@@ -1237,16 +1232,13 @@ def test_flash_attn_output(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('dropout_p', [0.17])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if USE_TRITON_ROCM:
if dropout_p != 0.0:
pytest.skip("Dropout not supported in AMD's Triton Backend yet")

if local == True:
pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")

@@ -1526,7 +1518,7 @@ def test_flash_attn_varlen_output(
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
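For context on the re-enabled dropout check: the test compares the measured fraction of dropped attention entries against dropout_p with a 0.01 tolerance (0.04 for local attention). The standalone sketch below shows that comparison in isolation; it is not the repository's get_dropout_fraction, which additionally handles varlen padding, causal masks, and the mask encoding returned by the kernel.

```python
import torch


def empirical_dropout_fraction(keep_mask, valid_mask):
    """Fraction of *valid* attention entries that were dropped.

    keep_mask:  bool tensor, True where the entry survived dropout.
    valid_mask: bool tensor, True inside the real (unpadded, unmasked) region.
    """
    dropped = (~keep_mask) & valid_mask
    return dropped.sum().item() / valid_mask.sum().item()


if __name__ == "__main__":
    torch.manual_seed(0)
    p = 0.17
    b, h, sq, sk = 4, 6, 512, 512
    valid = torch.ones(b, h, sq, sk, dtype=torch.bool)
    keep = torch.rand(b, h, sq, sk) >= p            # simulated keep mask
    frac = empirical_dropout_fraction(keep, valid)
    assert abs(frac - p) <= 0.01                    # same tolerance as the test
    print(f"dropout_p={p}, measured fraction={frac:.4f}")
```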
