From 6345da6672de32ac8543761c809797508c82991a Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 31 Oct 2024 16:39:11 -0400 Subject: [PATCH] Autotune off by default (#90) * Autotune off by default * rework tests --- .github/workflows/amd_tests.yml | 11 +++++++---- flash_attn/flash_attn_triton_amd/utils.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index b1441e4c7..5618a08cc 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -66,14 +66,17 @@ jobs: - name: Flash Attention Tests run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py - - name: AMD Kernel Tests + - name: AMD Tests run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -v -s flash_attn/flash_attn_triton_amd/test.py - - name: AMD Kernel Bench + - name: AMD Bench run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + python flash_attn/flash_attn_triton_amd/bench.py + - name: AMD Bench with Autotune + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index b59486495..530455063 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -3,7 +3,7 @@ import os import triton -AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '1').lower() in ('1', 'true', 'yes') +AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')