Skip to content

Commit

Permalink
target MI300 directly
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Jan 21, 2025
1 parent e728cab commit d1b6fd9
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/amd_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
echo '::set-output name=matrix-HIP::[["linux-mi300-gpu-1"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
Expand Down Expand Up @@ -59,13 +59,13 @@ jobs:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
- name: Flash Attention Tests Using Reference Impl
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner[0] == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
pytest tests/test_flash_attn_triton_amd.py
- name: Flash Attention CDNA Tests
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner[0] == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py
Expand All @@ -75,17 +75,17 @@ jobs:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache
- name: AMD Tests
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner[0] == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Bench
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner[0] == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
- name: AMD Bench with Autotune
if: matrix.runner[1] == 'gfx90a'
if: matrix.runner[0] == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
Expand Down

0 comments on commit d1b6fd9

Please sign in to comment.