Skip to content

Commit

Permalink
attn failing case
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Nov 20, 2024
1 parent a247186 commit c59b7b8
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions tests/test_flash_attn_triton_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,8 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("deterministic", [False])
Expand All @@ -897,27 +897,27 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [16])
@pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
# (16, 16),
# (256, 256),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
(256, 256),
# (113, 203),
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
# (512, 256),
# (1024, 1024),
# (1023, 1024),
# (1024, 1023),
# (2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
Expand Down Expand Up @@ -945,24 +945,24 @@ def test_flash_attn_output(
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
batch_size = 1
nheads = 1 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
q = torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap
if kvpacked:
kv = torch.randn(
kv = torch.ones(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
k = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
v = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
if alibi:
Expand Down Expand Up @@ -1108,7 +1108,7 @@ def test_flash_attn_output(
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
g = torch.ones_like(out)
do_o = (g.float() * out.float()).sum(-1)
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
if kvpacked:
Expand Down Expand Up @@ -1167,7 +1167,7 @@ def test_flash_attn_output(
if DEBUG:
print("attn:", attn, attn.shape)
print("attn_ref:", attn_ref, attn_ref.shape)
# assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
if DEBUG:
Expand Down Expand Up @@ -1216,7 +1216,6 @@ def test_flash_attn_output(
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
# (4, 4),
(1, 147),
(113, 203),
(128, 217),
Expand All @@ -1231,8 +1230,8 @@ def test_flash_attn_output(
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.17])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
Expand Down Expand Up @@ -1518,7 +1517,7 @@ def test_flash_attn_varlen_output(
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

if dropout_p > 0.0:
# assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
Expand Down

0 comments on commit c59b7b8

Please sign in to comment.