
Commit: use descale to set fp8

micmelesse committed Jan 16, 2025
1 parent: 6e2dcbf · commit: db4a331
Showing 2 changed files with 10 additions and 8 deletions.
.github/workflows/amd_tests.yml (7 additions, 1 deletion)
@@ -64,7 +64,13 @@ jobs:
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           export FLASH_ATTENTION_TRITON_AMD_REF=1
           pytest tests/test_flash_attn_triton_amd.py
-      - name: Flash Attention Tests
+      - name: Flash Attention CDNA Tests
+        if: matrix.runner[1] == 'gfx90a'
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache
+      - name: Flash Attention RDNA Tests
+        if: matrix.runner[1] == 'gfx1100'
         run: |
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           pytest tests/test_flash_attn_triton_amd.py
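Net effect of the workflow hunk: gfx90a (CDNA) runners now run only three targeted tests, while gfx1100 (RDNA) runners keep the full suite. Below is a hedged sketch of mirroring that gate on a local machine; reading the architecture via gcnArchName and the subprocess wrapper are assumptions of mine, not part of this commit.

```python
# Hedged local mirror of the CI gate above (not from this commit).
# Assumes a ROCm build of PyTorch, whose device properties expose
# gcnArchName, e.g. "gfx90a:sramecc+:xnack-".
import os
import subprocess

import torch

os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
arch = torch.cuda.get_device_properties(0).gcnArchName

if arch.startswith("gfx90a"):
    # CDNA2: the reduced subset the workflow selects.
    targets = [
        "tests/test_flash_attn_triton_amd.py::test_flash_attn_output",
        "tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output",
        "tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache",
    ]
else:
    # e.g. gfx1100 (RDNA3): the full test module.
    targets = ["tests/test_flash_attn_triton_amd.py"]

subprocess.run(["pytest", *targets], check=True)
```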
flash_attn/flash_attn_triton_amd/fwd_prefill.py (3 additions, 7 deletions)
@@ -635,13 +635,8 @@ def attention_prefill_forward_triton_impl(
         print("return_scores:", return_softmax)
         print("use_exp2:", use_exp2)
 
-    # Define FP8 types we support
-    FP8_TYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
-
-    # Simple check if tensors are FP8
-    is_fp8 = q.dtype in FP8_TYPES
-
-    if is_fp8:
+    if descale_q is not None:
+        is_fp8 = True
         if DEBUG:
             print("IS_FP8")

@@ -674,6 +669,7 @@ def attention_prefill_forward_triton_impl(
         p_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device)
         acc_fp8 = torch.zeros(o.shape, dtype=torch.float32, device=q.device)
     else:
+        is_fp8 = False
         # For non-FP8 types, use dummy values (no scaling needed)
         descale_q = descale_k = descale_v = descale_s = 1
         descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = descale_s_stride_z = 0
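For context on the dummy values above: under the common per-tensor convention, an FP8 tensor stores x / scale and is recovered as x_fp8.float() * descale with descale == scale, so descale = 1 means exactly "no scaling" and lets the FP8 and non-FP8 branches share one kernel signature. A hedged round-trip sketch of that convention (assumed, not code from this repo):

```python
# Round-trip sketch of per-tensor FP8 scaling, assuming the convention
# descale == scale, i.e. x is approximately x_fp8.float() * descale.
import torch

x = torch.randn(4, 8)

# Map the largest magnitude onto the format's maximum representable value.
scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
x_fp8 = (x / scale).to(torch.float8_e4m3fn)

descale = scale                    # descale = 1 would mean "no scaling",
x_back = x_fp8.float() * descale   # as in the non-FP8 branch above

print("max abs error:", (x - x_back).abs().max().item())
```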
