Skip to content

Commit

Permalink
Autotune off by default (#90)
Browse files Browse the repository at this point in the history
* Autotune off by default

* rework tests
  • Loading branch information
micmelesse authored Oct 31, 2024
1 parent 4e1d6ef commit 6345da6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/amd_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ jobs:
- name: Flash Attention Tests
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
pytest tests/test_flash_attn_triton_amd.py
- name: AMD Kernel Tests
- name: AMD Tests
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Kernel Bench
- name: AMD Bench
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
- name: AMD Bench with Autotune
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
python flash_attn/flash_attn_triton_amd/bench.py
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import triton

AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '1').lower() in ('1', 'true', 'yes')
AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')

Expand Down

0 comments on commit 6345da6

Please sign in to comment.