
Commit

remove skip config code
micmelesse committed Sep 3, 2024
1 parent 8110e1d commit 7594991
Showing 1 changed file with 0 additions and 49 deletions.
tests/test_flash_attn.py: 49 changes (0 additions, 49 deletions)
@@ -17,14 +17,6 @@
 from flash_attn.flash_attn_interface import _get_block_size_n
 from flash_attn.layers.rotary import apply_rotary_emb
 
-def is_power_of_2(n):
-    return n > 0 and (n & (n - 1)) == 0
-
-def skip_config(**kwargs):
-    if 'd' in kwargs:
-        return not is_power_of_2(kwargs['d'])
-    return False
-
 def is_hip():
     if torch.version.hip is not None:
         return True
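
Note: the deleted skip_config helper relied on the standard bit trick n & (n - 1) == 0, which is zero exactly when a positive integer has a single set bit. Below is a minimal, runnable sketch of how the removed check behaved; the example head dimensions (64, 96) are illustrative and not taken from the test parametrizations.

def is_power_of_2(n):
    # A positive integer is a power of two iff it has exactly one set bit;
    # n & (n - 1) clears the lowest set bit, so the result is 0 only in that case.
    return n > 0 and (n & (n - 1)) == 0

def skip_config(**kwargs):
    # Skip any test configuration whose head dimension `d` is not a power of two;
    # configurations that do not pass `d` are never skipped.
    if 'd' in kwargs:
        return not is_power_of_2(kwargs['d'])
    return False

# Illustrative values only:
assert skip_config(seqlen_q=128, seqlen_k=128, d=96)      # d=96 -> would have been skipped
assert not skip_config(seqlen_q=128, seqlen_k=128, d=64)  # d=64 -> would have run
assert not skip_config(seqlen=2048)                       # no `d` keyword -> would have run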
@@ -606,9 +598,6 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen=seqlen, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip() # Reference implementation OOM
     device = "cuda"
@@ -768,8 +757,6 @@ def test_flash_attn_varlen_qkvpacked(
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen=seqlen, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
 
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip() # Reference implementation OOM
@@ -954,9 +941,6 @@ def test_flash_attn_output(
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1239,9 +1223,6 @@ def test_flash_attn_varlen_output(
 
     if softcap != 0.0:
         pytest.skip("softcap not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -1557,9 +1538,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -1689,9 +1667,6 @@ def test_flash_attn_varlen_causal(
     if seqlen_q * seqlen_k >= 256 * 512:
         pytest.skip(f"{seqlen_q}, {seqlen_k} leads to out of memory on AMD")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1864,15 +1839,12 @@ def test_flash_attn_varlen_causal(
 def test_flash_attn_splitkv(
     seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
 ):
 
     if is_hip():
         test_backward = False
 
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
@@ -2047,9 +2019,6 @@ def test_flash_attn_kvcache(
     if has_leftpad == True:
         pytest.skip("cache_leftpad not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
@@ -2330,9 +2299,6 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
 
     if dropout_p != 0.0:
         pytest.skip("Dropout not supported in AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
@@ -2389,9 +2355,6 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
     if is_hip():
         if True:
             pytest.skip("Backward Attention not supported on AMD yet")
 
-    if skip_config(seqlen=seqlen, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
@@ -2452,9 +2415,6 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
     if is_hip():
         if True:
             pytest.skip("Backward Attention not supported on AMD yet")
 
-    if skip_config(seqlen=seqlen, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
@@ -2512,9 +2472,6 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
         if True:
             pytest.skip("Backward Attention not supported on AMD yet")
 
-    if skip_config(d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -2576,9 +2533,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -2645,9 +2599,6 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD yet")
 
-    if skip_config(seqlen_q=seqlen_q, seqlen_k=seqlen_k, d=d):
-        pytest.skip("Skipping configuration due to limited test time")
-
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
