Clean up for Upstream (#81)
* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* quantize zeros out the d stride
micmelesse authored Sep 4, 2024
1 parent 7b8a15c commit 75b5360
Showing 9 changed files with 233 additions and 986 deletions.
41 changes: 6 additions & 35 deletions .github/workflows/amd_tests.yml
@@ -4,11 +4,11 @@ on:
workflow_dispatch:
pull_request:
branches: [main_perf]
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf, micmelesse/upstream_pr]

concurrency:
group: ${{ github.ref }}
@@ -56,40 +56,11 @@ jobs:
cd ..
- name: Build
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
python setup.py install
# - name: Flash Attention Mini Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_output
# - name: Flash Attention qkvpacked Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked
# - name: Flash Attention output Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_output
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
# - name: Flash Attention causal Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_causal
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
# - name: Flash Attention kvcache Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_kvcache
# pytest tests/test_flash_attn.py::test_flash_attn_splitkv
# - name: Flash Attention race condition Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_race_condition
# - name: Flash Attention bwd Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_bwd_overflow
# pytest tests/test_flash_attn.py::test_flash_attn_bwd_transpose
# pytest tests/test_flash_attn.py::test_flash_attn_bwd_varlen_overflow
# - name: Flash Attention deterministic Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_deterministic
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_deterministic
- name: Flash Attention Tests
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
pytest tests/test_flash_attn.py
- name: AMD Kernel Tests
run: |
8 changes: 1 addition & 7 deletions .gitignore
@@ -19,17 +19,11 @@ var/
*.egg-info/
.installed.cfg
*.egg
.eggs

# IDE-related
.idea/

# Dev
venv

# AMD
.eggs
.vscode
core
scripts
log*
*csv
27 changes: 25 additions & 2 deletions README.md
@@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports:
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (CK), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.
@@ -121,10 +121,33 @@ We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

FlashAttention-2 with ROCm currently supports:
#### Composable Kernel Backend
The FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Forward pass head dimensions up to 256; backward pass head dimensions up to 128.
#### Triton Backend
The FlashAttention-2 ROCm Triton backend is a work in progress.
It currently supports the forward pass only; some features, such as PagedAttention and sliding-window attention, are still missing. It runs on both MI and Navi machines. We are working on the backward pass.

To use the Triton backend on ROCm, follow the steps below.

First, install Triton at the recommended [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
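
As an optional sanity check (not part of the upstream instructions), you can confirm that Python picks up the editable Triton build from the checkout:

```
# Quick sanity check that the editable Triton install is the one in use.
import triton
print(triton.__version__)  # dev version from the checkout
print(triton.__file__)     # should point into the cloned triton/python tree
```
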
Then install and test Flash Attention with the environment variable `FLASH_ATTENTION_USE_TRITON_ROCM` set to `"TRUE"`.

```
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```
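
Once installed, the Triton backend is used through the regular Flash Attention Python API, so calling code does not change. A minimal smoke test might look like the sketch below (it assumes a ROCm GPU, fp16 inputs, and the standard `flash_attn_func` signature; ROCm PyTorch still exposes the GPU as `"cuda"`):

```
import torch
from flash_attn import flash_attn_func

# q, k, v: (batch_size, seqlen, nheads, headdim), fp16 or bf16, on the GPU.
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```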


## How to use FlashAttention
11 changes: 4 additions & 7 deletions flash_attn/flash_attn_interface.py
@@ -4,16 +4,13 @@

import torch
import torch.nn as nn

def is_hip():
    if torch.version.hip is not None:
        return True
    return False
import os

# isort: off
# We need to import the CUDA kernels after importing torch
if is_hip():
    from . import flash_attn_triton_interface_amd as flash_attn_gpu
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
    from flash_attn import flash_attn_triton_interface_amd as flash_attn_gpu
else:
    import flash_attn_2_cuda as flash_attn_gpu
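
Because the backend is now chosen from an environment variable read at module import time, callers must export the flag before `flash_attn` is first imported. A hedged sketch of what that looks like from Python (the variable name is taken from the diff above; the rest is illustrative):

```
import os

# Must run before the first `import flash_attn`, because
# flash_attn_interface.py reads the flag at module import time.
os.environ["FLASH_ATTENTION_USE_TRITON_ROCM"] = "TRUE"

from flash_attn import flash_attn_func  # now routed to the Triton ROCm backend
```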

