one test skip most
micmelesse committed Aug 27, 2024
1 parent 1226d35 commit 9d002b9
Showing 3 changed files with 34 additions and 31 deletions.
59 changes: 31 additions & 28 deletions .github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
         id: set-matrix
         run: |
           if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
-            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm", "gfx942"]]'
+            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
           else
             echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
           fi
@@ -56,34 +56,37 @@ jobs:
       - name: Build
         run: |
           python setup.py install
-      - name: Flash Attention qkvpacked Tests
+      - name: Flash Attention Tests
         run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked
-          pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked
-      - name: Flash Attention output Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_output
-          pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
-      - name: Flash Attention causal Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_causal
-          pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
-      - name: Flash Attention kvcache Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_kvcache
-          pytest tests/test_flash_attn.py::test_flash_attn_splitkv
-      - name: Flash Attention race condition Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_race_condition
-      - name: Flash Attention bwd Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_bwd_overflow
-          pytest tests/test_flash_attn.py::test_flash_attn_bwd_transpose
-          pytest tests/test_flash_attn.py::test_flash_attn_bwd_varlen_overflow
-      - name: Flash Attention deterministic Tests
-        run: |
-          pytest tests/test_flash_attn.py::test_flash_attn_deterministic
-          pytest tests/test_flash_attn.py::test_flash_attn_varlen_deterministic
+          pytest tests/test_flash_attn.py
+      # - name: Flash Attention qkvpacked Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked
+      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked
+      # - name: Flash Attention output Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_output
+      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
+      # - name: Flash Attention causal Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_causal
+      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
+      # - name: Flash Attention kvcache Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_kvcache
+      #     pytest tests/test_flash_attn.py::test_flash_attn_splitkv
+      # - name: Flash Attention race condition Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_race_condition
+      # - name: Flash Attention bwd Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_overflow
+      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_transpose
+      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_varlen_overflow
+      # - name: Flash Attention deterministic Tests
+      #   run: |
+      #     pytest tests/test_flash_attn.py::test_flash_attn_deterministic
+      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_deterministic
       - name: AMD Kernel Tests
         run: |
           pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
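
The consolidation above replaces seven per-group steps with a single run over the whole test file. As a sketch of what the two forms do (pytest.main is pytest's standard programmatic entry point; the node IDs are copied from the diff):

import pytest

# New "Flash Attention Tests" step: collect and run every test in the file.
pytest.main(["tests/test_flash_attn.py"])

# Old per-group steps: select individual tests by pytest node ID, e.g.
pytest.main(["tests/test_flash_attn.py::test_flash_attn_qkvpacked"])
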
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -8,7 +8,7 @@
 import triton.language as tl
 from flash_attn.flash_attn_triton_kernel_prefill_amd import MetaData

-DEBUG = True
+DEBUG = False

 def _strides(x: torch.Tensor, *stride_names: str):
     if x is None:
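
This flip, and the matching one in tests/test_flash_attn.py below, silence module-level debug output for CI. A minimal sketch of the pattern such a flag usually gates (the debug_print helper is illustrative, not the kernel's actual code):

DEBUG = False  # flipped from True by this commit

def debug_print(*args):
    # Diagnostic prints run only while DEBUG is True, so CI logs stay clean.
    if DEBUG:
        print(*args)
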
4 changes: 2 additions & 2 deletions tests/test_flash_attn.py
@@ -19,7 +19,7 @@
 from flash_attn.flash_attn_interface import _get_block_size_n
 from flash_attn.layers.rotary import apply_rotary_emb

-DEBUG = True
+DEBUG = False

 MAX_HEADDIM_SM8x = 192

@@ -34,7 +34,7 @@ def is_amd():
         return True
     return False

-def skip_config(*args, reproducible=True, skip_pct = 0.95):
+def skip_config(*args, reproducible=True, skip_pct = 0.99):
     config_str = '_'.join(map(str, args))

     if reproducible:
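
The body of skip_config is truncated above; only the signature and the config_str line are visible. A minimal sketch of how such a reproducible, percentage-based skip could work, assuming the decision is derived from a hash of the config string — the hashing scheme and the pytest.skip call are assumptions, not the repository's actual code:

import hashlib
import random

import pytest

def skip_config(*args, reproducible=True, skip_pct=0.99):
    # Join every config value (dtype, shapes, flags, ...) into one string key.
    config_str = '_'.join(map(str, args))

    if reproducible:
        # Hash the config string into a stable value in [0, 1), so the same
        # config is always kept or always skipped across CI runs.
        digest = hashlib.md5(config_str.encode()).hexdigest()
        value = int(digest, 16) / 16**32
    else:
        # Otherwise draw a fresh random value on every run.
        value = random.random()

    if value < skip_pct:
        # Skip roughly skip_pct of all configs; 0.99 keeps about 1 in 100.
        pytest.skip(f"skipping config {config_str}")

Under this reading, raising skip_pct from 0.95 to 0.99 cuts the surviving configs from about 5% to about 1%, which fits the commit title: run the one consolidated test step, skip most configurations.
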
