fp8 backward #119

Draft · wants to merge 34 commits into base: main_perf

Changes from 14 commits

Commits (34)
156a9bc
Enable BWD fp8 with split kernel
micmelesse Jan 24, 2025
e6a67b3
add type info for backward
micmelesse Feb 6, 2025
c13c0f0
fix DEBUG flag bug
micmelesse Feb 6, 2025
acb05ad
fix bug with backward. Normal forward works with dropout. Segfault wi…
micmelesse Feb 6, 2025
ca267ed
pass descale strides
micmelesse Feb 6, 2025
00d1c6f
test causal
micmelesse Feb 6, 2025
3ba93db
fix causal compiler assert. min head should be 32
micmelesse Feb 6, 2025
3694224
remove descale_p
micmelesse Feb 6, 2025
7908150
save
micmelesse Feb 7, 2025
86fd7e6
explict name as causal
micmelesse Feb 7, 2025
01b370a
isolate bad case
micmelesse Feb 7, 2025
290d594
just run fp8 tests
micmelesse Feb 7, 2025
e9e4d6e
bench with autotune
micmelesse Feb 7, 2025
736a990
min changes
micmelesse Feb 7, 2025
9fc0d0a
cast_fp8 helper
micmelesse Feb 10, 2025
db15c3d
cast_varlen_to_fp8
micmelesse Feb 10, 2025
32d552e
save
micmelesse Feb 10, 2025
fbb00d6
minor
micmelesse Feb 10, 2025
9615417
highlight failing configs
micmelesse Feb 11, 2025
49e7db9
increase test cases
micmelesse Feb 11, 2025
7025eee
mark failing
micmelesse Feb 11, 2025
9238939
recategorize misc tests
micmelesse Feb 11, 2025
1714559
group failing gqa configs
micmelesse Feb 11, 2025
53ca7f9
add more tests
micmelesse Feb 12, 2025
51862e0
add vis code
micmelesse Feb 12, 2025
0fd32e4
min ci changes
micmelesse Feb 12, 2025
a05b551
dump folder
micmelesse Feb 12, 2025
e4d1385
single image per tensors
micmelesse Feb 12, 2025
d6f58e7
add tensor comparison
micmelesse Feb 12, 2025
ebafaff
gen varlen tensor
micmelesse Feb 13, 2025
504344c
vis varlen tensors
micmelesse Feb 13, 2025
174d8cb
varlen diff
micmelesse Feb 13, 2025
4714776
nice varlen vis
micmelesse Feb 13, 2025
2924726
vis function
micmelesse Feb 13, 2025
44 changes: 22 additions & 22 deletions .github/workflows/amd_tests.yml
@@ -50,33 +50,33 @@ jobs:
python setup.py install

# CDNA Tests
- name: Flash Attention Tests Using Reference Impl
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
pytest tests/test_flash_attn_triton_amd.py
- name: Flash Attention CDNA Tests
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py
# - name: Flash Attention Tests Using Reference Impl

Review comment: Question — my knowledge of GitHub Actions is almost none, so please take this comment with a grain of salt... As far as I can see, the MI300 integration job is commented out. Am I correct? Do we really want to merge this way?

# if: matrix.runner == 'linux-mi300-gpu-1'
# run: |
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
# export FLASH_ATTENTION_TRITON_AMD_REF=1
# pytest tests/test_flash_attn_triton_amd.py
# - name: Flash Attention CDNA Tests
# if: matrix.runner == 'linux-mi300-gpu-1'
# run: |
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
# pytest tests/test_flash_attn_triton_amd.py
- name: AMD Tests
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest -v -s flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_fp8 flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_varlen_fp8
- name: AMD Bench
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
- name: AMD Bench with Autotune
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
python flash_attn/flash_attn_triton_amd/bench.py
# - name: AMD Bench
# if: matrix.runner == 'linux-mi300-gpu-1'
# run: |
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
# python flash_attn/flash_attn_triton_amd/bench.py
# - name: AMD Bench with Autotune
# if: matrix.runner == 'linux-mi300-gpu-1'
# run: |
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
# export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
# python flash_attn/flash_attn_triton_amd/bench.py

# RDNA Tests
- name: Flash Attention RDNA Tests
60 changes: 39 additions & 21 deletions flash_attn/flash_attn_interface.py
@@ -93,8 +93,7 @@ def _flash_attn_forward(
return_softmax: bool,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None
descale_v: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
@@ -113,8 +112,7 @@ def _flash_attn_forward(
None,
descale_q,
descale_k,
descale_v,
descale_p
descale_v
)
return out, softmax_lse, S_dmask, rng_state

@@ -175,7 +173,6 @@ def _flash_attn_varlen_forward(
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
@@ -202,8 +199,7 @@ def _flash_attn_varlen_forward(
None,
descale_q,
descale_k,
descale_v,
descale_p
descale_v
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -273,6 +269,10 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +301,10 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_do
)
return softmax_d

@@ -369,6 +373,10 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -402,6 +410,10 @@ def _flash_attn_varlen_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_do
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
@@ -823,7 +835,7 @@ def forward(
descale_q,
descale_k,
descale_v,
descale_p
descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -846,10 +858,9 @@ def forward(
return_softmax=return_softmax and dropout_p > 0,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
descale_v=descale_v
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_do)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +873,7 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,6 +898,10 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
@@ -917,7 +932,7 @@ def forward(
descale_q,
descale_k,
descale_v,
descale_p
descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -945,11 +960,10 @@ def forward(
block_table=block_table,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p
descale_v=descale_v
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_do
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
@@ -965,7 +979,7 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(2)
dout_padded = dout
@@ -994,6 +1008,10 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
@@ -1151,7 +1169,7 @@ def flash_attn_func(
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1216,7 +1234,7 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
descale_p
descale_do
)


@@ -1396,7 +1414,7 @@ def flash_attn_varlen_func(
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1473,7 +1491,7 @@ def flash_attn_varlen_func(
descale_q,
descale_k,
descale_v,
descale_p
descale_do
)


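For readers skimming this diff, here is a minimal, self-contained sketch of the plumbing pattern the interface changes follow: the descale factors enter the autograd `forward`, are stashed with `ctx.save_for_backward`, and are retrieved in `backward` for the backward kernel. The attention math and zero gradients below are placeholders standing in for the real `flash_attn_gpu.fwd`/`bwd` calls; the class name and the assumed `(batch, heads, seqlen, headdim)` layout are illustrative only.

```python
import torch

class Fp8AttnSketch(torch.autograd.Function):
    """Toy mirror of the ctx plumbing in this diff (not the real kernels)."""

    @staticmethod
    def forward(ctx, q, k, v, descale_q, descale_k, descale_v, descale_do):
        # Stand-in for flash_attn_gpu.fwd: plain fp32 attention on descaled inputs,
        # assuming a (batch, heads, seqlen, headdim) layout for simplicity.
        qf, kf, vf = q.float() * descale_q, k.float() * descale_k, v.float() * descale_v
        out = torch.softmax(qf @ kf.transpose(-1, -2) * qf.shape[-1] ** -0.5, dim=-1) @ vf
        # Descale factors are saved next to the activations so backward() can use them.
        ctx.save_for_backward(q, k, v, out, descale_q, descale_k, descale_v, descale_do)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out, descale_q, descale_k, descale_v, descale_do = ctx.saved_tensors
        # Stand-in for flash_attn_gpu.bwd: descale_do would rescale dout, and
        # descale_q/k/v the saved fp8 activations, before computing the real grads.
        dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
        return dq, dk, dv, None, None, None, None
```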
3 changes: 3 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -61,6 +61,9 @@ Inside the docker, it should open to the flash attention repo with everything in
pytest tests/test_flash_attn_triton_amd.py
```

##### FP8
In our fork, we have modified the API to work with FP8. You provide tensors that have already been scaled into FP8 range, together with their associated descaling factors.
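For illustration (not part of the diff itself), a minimal sketch of how this FP8 path might be called. The `quantize_fp8` helper, the per-tensor `(1,)`-shaped descale layout, and the choice of `torch.float8_e4m3fnuz` are assumptions made for the example, not part of the library; only the `descale_q`/`descale_k`/`descale_v`/`descale_do` keywords come from the diff above.

```python
import torch
from flash_attn import flash_attn_func

def quantize_fp8(x, fp8_dtype=torch.float8_e4m3fnuz):
    # Hypothetical helper: scale a tensor into fp8 range and return the
    # descaling factor the kernel needs to recover the original magnitude.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
    x_fp8 = (x.to(torch.float32) * (fp8_max / amax)).to(fp8_dtype)
    descale = (amax / fp8_max).reshape(1)
    return x_fp8, descale

batch, seqlen, nheads, headdim = 2, 128, 4, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

q_fp8, descale_q = quantize_fp8(q)
k_fp8, descale_k = quantize_fp8(k)
v_fp8, descale_v = quantize_fp8(v)
# descale_do is consumed by the backward pass to descale the incoming gradient.
descale_do = torch.ones(1, device="cuda", dtype=torch.float32)

out = flash_attn_func(
    q_fp8, k_fp8, v_fp8,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
    descale_do=descale_do,
)
```

Per the interface diff above, the descale tensors are saved with `ctx.save_for_backward` and forwarded to the backward kernels, so the same factors serve both passes.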

##### Credits
AMD Triton kernels team
