Clean up for Upstream #81

Merged
merged 14 commits into main_perf from micmelesse/upstream_pr_fork on Sep 4, 2024

Conversation

micmelesse
Collaborator

Hi, this is a PR to add a Triton backend to Flash Attention on ROCm. We hope it will be the first in a series of PRs to that end. Triton has supported ROCm for a while now, and a Triton backend for Flash Attention will allow us to support Flash Attention on both our MI and Navi machines.

In this PR, we enable the major parts of fwd, varlen_fwd and fwd_kvcache. However, some features are still missing, such as dropout, sliding window, rotary embedding and paged attention, and there are a few miscellaneous bugs. These will all be addressed in subsequent PRs. The next PR we plan to file will add support for bwd and varlen_bwd; if we should reprioritize, please let us know.
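
To make the scope concrete, here is a minimal sketch (not taken from this PR) of exercising the enabled forward paths through the standard flash-attn Python API, i.e. flash_attn_func for fwd and flash_attn_with_kvcache for fwd_kvcache. The shapes, sizes, and the restriction to dropout_p=0.0 are illustrative assumptions reflecting the limitations listed above, and any backend-selection environment flag is omitted since its exact name is not given here.

```python
# Minimal usage sketch under the assumptions noted above; not code from this PR.
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

device, dtype = "cuda", torch.float16  # ROCm GPUs are exposed through the "cuda" device in PyTorch
batch, seqlen, nheads, headdim = 2, 128, 8, 64

# Plain forward pass (fwd); dropout stays at 0.0 because dropout is not yet supported here.
q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Decode-style forward against a KV cache (fwd_kvcache), without rotary or paged attention.
cache_len = 512
k_cache = torch.randn(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
q_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)
out_decode = flash_attn_with_kvcache(q_new, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)

print(out.shape, out_decode.shape)  # (2, 128, 8, 64) and (2, 1, 8, 64)
```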

We have tested this PR here on an MI200 machine. When testing the Triton backend for ROCm, we skip the backward pass, configs with unsupported features, and a random portion of head sizes (d); the latter keeps the test times reasonable. The latest results we have are === 64 failed, 30387 passed, 478321 skipped, 1 warning in 3110.86s (0:51:50) ===. There is clearly more work to be done, but we hope this makes a good start.
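
For illustration only, the random head-size skipping described above might look something like the following pytest sketch; the skip fraction, the parameter grid, the test name, and the reference check are assumptions, not the PR's actual test code.

```python
# Hypothetical sketch of randomly skipping a portion of head sizes (d) to keep
# test time reasonable; not the actual test harness from this PR.
import random

import pytest
import torch
from flash_attn import flash_attn_func

SKIP_FRACTION = 0.75  # assumed fraction of head-size cases to skip per run


@pytest.mark.parametrize("d", [32, 40, 64, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("causal", [False, True])
def test_fwd_matches_reference(d, causal):
    if random.random() < SKIP_FRACTION:
        pytest.skip("randomly skipping this head size to keep test times reasonable")
    q = torch.randn(1, 64, 4, d, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    # Reference via PyTorch SDPA, which expects (batch, nheads, seqlen, headdim).
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    ).transpose(1, 2)
    torch.testing.assert_close(out, ref.to(out.dtype), atol=2e-2, rtol=2e-2)
```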

Please let us know what we can do on our end to help with this process.
Finally, this PR includes work from multiple people besides myself; special thanks to @vgokhale, @scxiao and @jlgreathouse.

@micmelesse changed the title from Micmelesse/upstream pr fork to Clean up for Upstream on Sep 4, 2024
@micmelesse marked this pull request as ready for review on September 4, 2024, 12:21
@micmelesse merged commit 75b5360 into main_perf on Sep 4, 2024
4 of 7 checks passed
@micmelesse deleted the micmelesse/upstream_pr_fork branch on September 4, 2024, 12:22
micmelesse added a commit that referenced this pull request Oct 14, 2024
* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* quantize zeroes out the d stride
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable fwd and varlen_fwd on AMD  (#63)

* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel separate

* clean up

enable flash_attn_with_kvcache (#68)

* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fails on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotatary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexiable dim, n-heads, headofdim

* better benchmarking

* basic inplace update working

* works upto 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat

Fixes after rebase

Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)

Enable Vanilla Bwd and Refactor (#86)

* Vanilla BWD

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move arround impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename to M to L

save

add notes

add old_bwd back

fa failure fails in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse does not match reference

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is faliing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen- launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen

* fix delta bug

* varlen bwd working

* save

* runs on Mi200

* just test basics

* save

* fix bug

* fix varlen in64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmaxlse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax lse. strides in varlen were not right

* add causal tests and 32*32 bwd doesnot have segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qvpacked error

* fix varlen qvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time

* clean up readmes
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable Fwd and Backward

REBASE: fix interface changes in rebase

rename test to test_flash_attn_triton_amd

REBASE: fix unpad diffs

minor clean up in setup
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable Fwd and Backward

FLASH_ATTENTION_TRITON_AMD flags

bench fwd and bwd

fix sequence_parallel
rocking5566 pushed a commit that referenced this pull request Dec 17, 2024
* Enable Fwd and Backward

* clean up

* Enable sequence_parallel in bwd (#89)

* sequence_parallel working on bwd_impl test

* fix qkv error

* save

* save

* save

* bwd 3 times faster

* clean up

* fix varlen bug

* use copy back dict

* fix qkvpacked bug

* reduce bench sizes

* print copy back

* clean more

* Autotune off by default

* update Triton commit readme (#92)
micmelesse added a commit that referenced this pull request Feb 20, 2025
Enable Fwd and Backward

Enable sequence_parallel in bwd (#89)

* sequence_parallel working on bwd_impl test

* fix qkv error

* save

* save

* save

* bwd 3 times faster

* clean up

* fix varlen bug

* use copy back dict

* fix qkvpacked bug

* reduce bench sizes

* print copy back

Autotune off by default (#90)

* Autotune off by default

* rework tests

Update Triton Version (#91)

* ignore ck code

* update triton

update Triton commit readme (#92)

Fix README (#96)

* Update README.md

* fix readme

Enable MQA/GQA in backward (#100)

* simple failing test

* ref is working

* fix bug

* save

* find failing case

* forward varlen mqa/gqa works

* add mqa configs to bwd test

* varlen bwd ref fixed

* save failing case

* GQA flag

* ones passes

* go back to values

* save

* bhsd working with mqa

* remove repo

* test layouts

* clean up

* test back to normal

* clean up more

* use zeros_like

* zero out

Added Support for Rotary Positional Embeddings (#99)

* feat: added rotary support in kvcache

* confirmed non-fused rotary passes all tests

add RDNA CI (#105)

* Add RDNA CI

This is a combination of 4 commits.

try navi

try matrix

small change

try minimal change

* limit navi tests

* stop casting to fp32 which leads to oom on navi

* enable all causal

* revert all causal

* skip compiler bug on navi

Dropout (#101)

* Alex's work

This is a combination of 11 commits.

save

fix: dropout=0.0 works

feat: dropout restrictions removed. failing tests

test: reduced tests to simple cases

test: failure is due to query + key padding mask NOT varlen itself

feat: varlen dropout fwd passes

fix: varlen bwd dropout works!

test: discovered  bwd error for non-dropout cases for large seqlen

save

save

use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes

* Almost Everything works.

This is a combination of 16 commits.

Work so far

This is a combination of 63 commits.

pick test case

save philox offsets into metadata

pass offset to ref

common dropout mask

simple droput out mask

start dropout ref. work on returning SD_Mask next with negative numbers

refernce is working

dropout bwd ref faling case

transfer rng_state properly

save changes

one dropout mask function

save

save

minizmize diff

save

use torch.where in backward

save

save

save

dk works!

passes

reference is working. TODO" attn_ref is broken

varlen ref working

attn failing case

with ones. attn_ref matches. fails with randn. we are seeing failure with large sizes from dv.

save

skip attn matrices

compare the masks and find failing case

rm cdiv_fn

put dropout and alibi in common

save

compare masks

save

save

pytorch ref is using tiles

save

save

tl_rand_ref

cache ref dropout mask

new generate_dropout_mask_ref using tiling

issolate failing varlen case

simple dropout

loop on k

print rng_outputs

save

fwd kernel works

save

dv passed

close to dk

simple ref

save

seperate droped and scaled in ref and triton kernel

ref changes

working delta with dp

find failing dv failures

find failing case due to delta

save

delta from dp working

bwd impl green

enable test fwd

save

save

delete kernels

save

probably mask application mismatch

dump forward dropout

pass dropout mask tensor to bwd_core

different dropout fraction in fwd and bwd

mismatch found on columns greater than 64

fix dropout bug. philox was not offset

run full suite

stop debug and approximate delta

fix drop_mask non issue

skip attn check

clean up common

bad varlen config

fix varlen bug

save

* fix datatype mismatch

* clean up

* use pytorch dropout

* It works on MI300.

* remove _bwd_preprocess_use_p

* fix torch interface bug

---------

Co-authored-by: Alex Kranias <[email protected]>

fp8 forward (#116)

* disable navi

* start test

* test fp16 against fp8

* save scaling code so far

* global scaling

* add per_head_scaling

* dump qk

* save dumping q, k and qk to fp32 tensor

* fix pointer bug

* save reproducer

* dump p and acc

* fp8 working with my debug input

* save

* change api for dequant

* pass descale_p

* clean up

* most working

* save

* save

* varlen half way

* some varlen examples work

* improve varlen debug input

* varlen mostly working

* push working cases

* fix ref bug

* fix backward bug

* fix varlen backward bug

* use descale to set fp8

* check arch fp8 support

* cache arch

* try again

* skip bad config on MI200

* skip decode nan config on MI200

* fix mistake

* skip more

* run full suite

* Update amd_tests.yml

* address comments

* navi ci is broken

* raise error tolerance to 2.5e-1

* target MI300 directly

* show gfx

* try again

* don't fail matrix if one path fails

* try upstream triton

* just get MI300 working

* Fix install bug

This is a combination of 5 commits.

try this

use --no-build-isolation

put route at .python

run full suite

remove triton

* run ref on cpu

* move ref test to navi machines

* pin triton

* add bench deps

Update readme

Minor fixes (#107)

* Clean up

This is a combination of 4 commits.

update base image

disable navi for now

all causal seems to work on MI300

skip MI200 causal bugs

* remove MI200 skips

* just run on prs or manually

* add navi back

* try again

* update readme

* mark flakey test

* ref bug

Performant backward Triton implementation with separated dkdv and dq kernels (#122)

* added the split file

* overhauled split file, need to add new kernels

* copied triton fa over for reference

* added comments

* preprocess and dkdv done

* fixed dkdv, added dq

* fixed assumption on q, kv length different, run but incorrect

* added standalone test for split bwd kernel

* minor change on the ptr arith

* separated the dkdv and dq kernels

* GQA works now, onto seqlen q != k

* dk,dq working, dv still failing

* fixed the masking and num_step calc, now q==k works

* added debug print with interpreter, might not work entirely w/o next commit

* fixed all issues with q != k

* fixed varlen issue

* fixup on debug print

* fixed dropout, esp w/ varlen

* added USE_EXP2 toggle

* added noncausal kernel

* updated internal test for noncausal and use_exp2

* formatting

* fixed dropout from seed bug

* added envvar USE_SPLIT to toggle btw bwd kernels

* fixed the qkv pack issue and removed hack

* added the split kernel into interface_fa.py

* change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default

* removed redundant file

* fixed missing import in test

* fixed import in interface_fa.py

* revert changes in flash_attn_interface.py

* updated strides to adapt to various tensor init shape

* fixed issue that dqkv not zero'd

* disabled the AMD local test

Quick Fixes (#124)

* fix fp8 bug

* fix type bug

* forgot nones

* docker file

reenable gfx1100 ci (#121)

* reenable

* randomly sample

* clean up ci

* add pytest-randomly

* try again

update triton commit (#128)

* update triton commit

* disable navi

update base docker image (#129)

fp8 BWD after figuring out varlen problem

This is a combination of 21 commits.

fp8 BWD

Enable BWD fp8 with split kernel

Enable BWD fp8 with per block scale factors for p
and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs(())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save

add type info for backward

fix  DEBUG flag bug

fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides.

pass descale strides

test causal

fix causal compiler assert. min head should be 32

remove descale_p

save

explict name as causal

isolate bad case

just run fp8 tests

bench with autotune

min changes

cast_fp8 helper

cast_varlen_to_fp8

save

minor

highlight failing configs

increase test cases

mark failing

recategorize misc tests

group failing gqa configs

add more tests

add vis code

min ci changes

dump folder

single image per tensors

add tensor comparison

gen varlen tensor

vis varlen tensors

varlen diff

nice varlen vis

vis function

show seqlen in varlen

add vis_tensors function

simplify

add color bars

rm vis from test

set canvas size.

descale values are optional

add ck tests

add flag to build ck

rm ck test

assert requires grad

ensure q, k, and v require gradients

split vis

rm interp, 8k and 300 dpi

slice per page

disable ci for now

add more vis code

tensor per image is better

for vis_close, don't vis if no error. also vis all failing varlen tests

varlen failures due to different seqlens

rm vis code
micmelesse added a commit that referenced this pull request Feb 20, 2025
Enable Fwd and Backward

Enable Fwd and Backward

Enable Fwd and Backward

Enable fwd and varlen_fwd on AMD  (#63)

* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel seperate

* clean up

enable flash_attn_with_kvcache (#68)

* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fails on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotatary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexiable dim, n-heads, headofdim

* better benchmarking

* basic inplace update working

* works upto 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat

Fixes after rebase

Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)

* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* qunatize zero's out the d stride

Enable Vanilla Bwd and Refactor (#86)

* Vanilla BWD

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move arround impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename to M to L

save

add notes

add old_bwd back

fa failure fails in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse doesnot match refernce

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is faliing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen- launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen

* fix delta bug

* varlen bwd working

* save

* runs on Mi200

* just test basics

* save

* fix bug

* fix varlen in64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmaxlse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax lse. strides in varlen werenot right

* add causal tests and 32*32 bwd doesnot have segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qvpacked error

* fix varlen qvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time

* clean up readmes

REBASE: fix interface changes in rebase

rename test to test_flash_attn_triton_amd

REBASE: fix unpad diffs

minor clean up in setup

FLASH_ATTENTION_TRITON_AMD flags

bench fwd and bwd

fix sequence_parallel

Enable sequence_parallel in bwd (#89)

* sequence_parallel working on bwd_impl test

* fix qkv error

* save

* save

* save

* bwd 3 times faster

* clean up

* fix varlen bug

* use copy back dict

* fix qkvpacked bug

* reduce bench sizes

* print copy back

Autotune off by default (#90)

* Autotune off by default

* rework tests

Update Triton Version (#91)

* ignore ck code

* update triton

update Triton commit readme (#92)

Fix README (#96)

* Update README.md

* fix readme

Enable MQA/GQA in backward (#100)

* simple failing test

* ref is working

* fix bug

* save

* find failing case

* fowrad varlen mqa/gqa works

* add mqa configs to bwd test

* varlen bwd ref fixed

* save failing case

* GQA flag

* ones passes

* go back to values

* save

* bhsd working with mqa

* remove repo

* test layouts

* clean up

* test back to normal

* clean up more

* use zeros_like

* zero out

Added Support for Rotary Positional Embeddings (#99)

* feat: added rotary support in kvcache

* confirmed non-fused rotary passes all tests

add RDNA CI (#105)

* Add RDNA CI

This is a combination of 4 commits.

try navi

try matrix

small change

try minimal change

* limit navi tests

* stop casting to fp32 which leads to oom on navi

* enable all causal

* revert all causal

* skip compiler bug on navi

Dropout (#101)

* Alex's work

This is a combination of 11 commits.

save

fix: dropout=0.0 woorks

feat: dropout restrictions removed. failing tests

test: reduced tests to simple cases

test: failure is due to query + key padding mask NOT varlen itself

feat: varlen dropout fwd passes

fix: varlen bwd dropout works!

test: discovered  bwd error for non-dropout cases for large seqlen

save

save

use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes

* Almost Everything works.

This is a combination of 16 commits.

Work so far

This is a combination of 63 commits.

pick test case

save philox offsets into metadata

pass offset to ref

common dropout mask

simple droput out mask

start dropout ref. work on returning SD_Mask next with negative numbers

refernce is working

dropout bwd ref faling case

transfer rng_state properly

save changes

one dropout mask function

save

save

minimize diff

save

use torch.where in backward

save

save

save

dk works!

passes

reference is working. TODO: attn_ref is broken

varlen ref working

attn failing case

with ones. attn_ref matches. fails with randn. we are seeing failure with large sizes from dv.

save

skip attn matrices

compare the masks and find failing case

rm cdiv_fn

put dropout and alibi in common

save

compare masks

save

save

pytorch ref is using tiles

save

save

tl_rand_ref

cache ref dropout mask

new generate_dropout_mask_ref using tiling

isolate failing varlen case

simple dropout

loop on k

print rng_outputs

save

fwd kernel works

save

dv passed

close to dk

simple ref

save

separate dropped and scaled in ref and triton kernel

ref changes

working delta with dp

find failing dv failures

find failing case due to delta

save

delta from dp working

bwd impl green

enable test fwd

save

save

delete kernels

save

probably mask application mismatch

dump forward dropout

pass dropout mask tensor to bwd_core

different dropout fraction in fwd and bwd

mismatch found on columns greater than 64

fix dropout bug. philox was not offset

run full suite

stop debug and approximate delta

fix drop_mask non issue

skip attn check

clean up common

bad varlen config

fix varlen bug

save

* fix datatype mismatch

* clean up

* use pytorch dropout

* It works on MI300.

* remove _bwd_preprocess_use_p

* fix torch interface bug

---------

Co-authored-by: Alex Kranias <[email protected]>
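
The dropout work in #101 keys the mask off a Philox seed/offset pair so the forward and backward passes regenerate exactly the same pattern; the "philox was not offset" fix above is what happens when they do not. A reference-style sketch of how the mask is applied to the attention probabilities; the kernel would draw the mask with Triton's RNG, and the torch.Generator here only stands in for that seed/offset.

```python
import torch

def dropout_attention_probs(p, dropout_p, seed):
    # p: post-softmax attention probabilities, e.g. (seqlen_q, seqlen_k).
    gen = torch.Generator().manual_seed(seed)        # stand-in for the Philox seed/offset
    keep = torch.rand(p.shape, generator=gen) >= dropout_p
    # Kept entries are rescaled by 1 / (1 - p) so the expectation is unchanged;
    # the backward pass must regenerate the identical keep mask.
    return torch.where(keep, p / (1.0 - dropout_p), torch.zeros_like(p))

p = torch.softmax(torch.randn(8, 8), dim=-1)
p_dropped = dropout_attention_probs(p, dropout_p=0.2, seed=1234)
```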

fp8 forward (#116)

* disable navi

* start test

* test fp16 against fp8

* save scaling code so far

* global scaling

* add per_head_scaling

* dump qk

* save dumping q, k and qk to fp32 tensor

* fix pointer bug

* save reproducer

* dump p and acc

* fp8 working with my debug input

* save

* change api for dequant

* pass descale_p

* clean up

* most working

* save

* save

* varlen half way

* some varlen examples work

* improve varlen debug input

* varlen mostly working

* push working cases

* fix ref bug

* fix backward bug

* fix varlen backward bug

* use descale to set fp8

* check arch fp8 support

* cache arch

* try again

* skip bad config on MI200

* skip decode nan config on MI200

* fix mistake

* skip more

* run full suite

* Update amd_tests.yml

* address comments

* navi ci is broken

* raise error tolerance to 2.5e-1

* target MI300 directly

* show gfx

* try again

* don't fail matrix if one path fails

* try upstream triton

* just get MI300 working

* Fix install bug

This is a combination of 5 commits.

try this

use --no-build-isolation

put route at .python

run full suite

remove triton

* run ref on cpu

* move ref test to navi machines

* pin triton

* add bench deps
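
The fp8 forward path in #116 quantizes q/k/v into an 8-bit float format and carries descale factors (global or per head, per the commits above) that get multiplied back in after the dot products. A minimal sketch of that convention, assuming torch.float8_e4m3fn is available; the actual interface and scaling granularity in the repo may differ.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max   # roughly 448 for e4m3

def quantize_fp8(x):
    # descale is what the kernel multiplies by to undo the quantization:
    #   x ≈ x_fp8.float() * descale
    descale = x.abs().amax().clamp(min=1e-12) / FP8_MAX
    return (x / descale).to(torch.float8_e4m3fn), descale

q, k = torch.randn(16, 64), torch.randn(16, 64)
q_fp8, descale_q = quantize_fp8(q)
k_fp8, descale_k = quantize_fp8(k)

# After the qk matmul the two descale factors are folded back in.
qk = (q_fp8.float() @ k_fp8.float().t()) * descale_q * descale_k
max_err = (qk - q @ k.t()).abs().max()   # quantization error, hence the looser test tolerances
```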

Update readme

Minor fixes (#107)

* Clean up

This is a combination of 4 commits.

update base image

disable navi for now

all causal seems to work on MI300

skip MI200 causal bugs

* remove MI200 skips

* just run on prs or manually

* add navi back

* try again

* update readme

* mark flaky test

* ref bug

Performant backward Triton implementation with separated dkdv and dq kernels (#122)

* added the split file

* overhauled split file, need to add new kernels

* copied triton fa over for reference

* added comments

* preprocess and dkdv done

* fixed dkdv, added dq

* fixed assumption about q and kv lengths differing; runs but incorrect

* added standalone test for split bwd kernel

* minor change on the ptr arith

* separated the dkdv and dq kernels

* GQA works now, onto seqlen q != k

* dk,dq working, dv still failing

* fixed the masking and num_step calc, now q==k works

* added debug print with interpreter, might not work entirely w/o next commit

* fixed all issues with q != k

* fixed varlen issue

* fixup on debug print

* fixed dropout, esp w/ varlen

* added USE_EXP2 toggle

* added noncausal kernel

* updated internal test for noncausal and use_exp2

* formatting

* fixed dropout from seed bug

* added envvar USE_SPLIT to toggle btw bwd kernels

* fixed the qkv pack issue and removed hack

* added the split kernel into interface_fa.py

* change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default

* removed redundant file

* fixed missing import in test

* fixed import in interface_fa.py

* revert changes in flash_attn_interface.py

* updated strides to adapt to various tensor init shape

* fixed issue where dqkv was not zeroed

* disabled the AMD local test
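
The split backward in #122 follows the separated-kernel structure used by FlashAttention-2 style backward passes: one kernel keeps a K/V tile resident and sweeps Q tiles to accumulate dK/dV, the other keeps a Q tile resident and sweeps K/V tiles to accumulate dQ, so no program ever writes another program's output and no atomics are needed. Both sides implement the standard per-tile relations sketched below (textbook FA2 math, not copied from this kernel), with the preprocess kernel supplying the per-row delta.

```latex
\begin{aligned}
D_i     &= \operatorname{rowsum}(dO_i \circ O_i) \\
dV_j    &= \textstyle\sum_i P_{ij}^{\top} dO_i, \qquad dP_{ij} = dO_i V_j^{\top} \\
dS_{ij} &= P_{ij} \circ (dP_{ij} - D_i) \\
dQ_i    &= \tau \textstyle\sum_j dS_{ij} K_j, \qquad dK_j = \tau \textstyle\sum_i dS_{ij}^{\top} Q_i
\end{aligned}
```

Here τ is the softmax scale; the sums over i and j are what the Q-sweep and K/V-sweep loops materialize.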

Quick Fixes (#124)

* fix fp8 bug

* fix type bug

* forgot nones

* docker file

reenable gfx1100 ci (#121)

* reenable

* randomly sample

* clean up ci

* add pytest-randomly

* try again

update triton commit (#128)

* update triton commit

* disable navi

update base docker image (#129)

fp8 BWD after figuring out varlen problem

This is a combination of 21 commits.

fp8 BWD

Enable BWD fp8 with split kernel

Enable BWD fp8 with per-block scale factors for p and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save

add type info for backward

fix DEBUG flag bug

fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides.

pass descale strides

test causal

fix causal compiler assert. min head should be 32

remove descale_p

save

explict name as causal

isolate bad case

just run fp8 tests

bench with autotune

min changes

cast_fp8 helper

cast_varlen_to_fp8

save

minor

highlight failing configs

increase test cases

mark failing

recategorize misc tests

group failing gqa configs

add more tests

add vis code

min ci changes

dump folder

single image per tensors

add tensor comparison

gen varlen tensor

vis varlen tensors

varlen diff

nice varlen vis

vis function

show seqlen in varlen

add vis_tensors function

simplify

add color bars

rm vis from test

set canvas size.

descale values are optional

add ck tests

add flag to build ck

rm ck test

assert requires grad

ensure q, k, and v require gradients

split vis

rm interp, 8k and 300 dpi

slice per page

disable ci for now

add more vis code

tensor per image is better

for vis_close, don't vis if no error. also vis all failing varlen tests

varlen failures due to different seqlens

rm vis code
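
The per-block descale_p / descale_ds commits above extend the fp8 scaling idea to the intermediate P and dS tiles of the backward pass: each tile picks its own scale from max(abs(·)) before being cast to fp8 for the matmul, and the factor is multiplied back in afterwards. Schematically (a reading of the commit messages, not the literal kernel code):

```latex
\operatorname{descale}_X = \frac{\max\lvert X_{\text{tile}}\rvert}{\mathrm{FP8\_MAX}}, \qquad
X_{\mathrm{fp8}} = \operatorname{cast}_{\mathrm{fp8}}\!\left(\frac{X_{\text{tile}}}{\operatorname{descale}_X}\right), \qquad
X_{\text{tile}} \approx X_{\mathrm{fp8}} \cdot \operatorname{descale}_X
```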

micmelesse added a commit that referenced this pull request Feb 20, 2025

Enable Fwd and Backward

Enable Fwd and Backward

Enable fwd and varlen_fwd on AMD  (#63)

* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel separate

* clean up

enable flash_attn_with_kvcache (#68)

* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fail on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexible dim, n-heads, headofdim

* better benchmarking

* basic inplace update working

* works upto 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat
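
Most of the kvcache path in #68 revolves around writing new_kv into the right slice of a preallocated cache before attending over it, with cache_seqlens recording how much of each cache entry is already filled. A minimal sketch of that in-place update; the names loosely follow the kvcache interface, and the paged and has_batch_idx variants add an extra indirection on top.

```python
import torch

batch, cache_len, nheads_k, d = 2, 128, 4, 64
k_cache = torch.zeros(batch, cache_len, nheads_k, d)
v_cache = torch.zeros(batch, cache_len, nheads_k, d)
cache_seqlens = torch.tensor([10, 37])    # tokens already present per batch entry

seqlen_new = 1                            # e.g. a single decode step
k_new = torch.randn(batch, seqlen_new, nheads_k, d)
v_new = torch.randn(batch, seqlen_new, nheads_k, d)

# Append new_kv at each entry's current length; attention then covers the
# first cache_seqlens[b] + seqlen_new positions of the cache.
for b in range(batch):
    start = int(cache_seqlens[b])
    k_cache[b, start:start + seqlen_new] = k_new[b]
    v_cache[b, start:start + seqlen_new] = v_new[b]
```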

Fixes after rebase

Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)

* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* quantize zeroes out the d stride

Enable Vanilla Bwd and Refactor (#86)

* Vanilla BWD

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move arround impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename M to L

save

add notes

add old_bwd back

fa failure reproduces in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse does not match reference

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is failing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen

* fix delta bug

* varlen bwd working

* save

* runs on MI200

* just test basics

* save

* fix bug

* fix varlen int64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmax_lse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax_lse. strides in varlen were not right

* add causal tests and 32*32 bwd does not have segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qvpacked error

* fix varlen qvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time

* clean up readmes

REBASE: fix interface changes in rebase

rename test to test_flash_attn_triton_amd

REBASE: fix unpad diffs

minor clean up in setup

FLASH_ATTENTION_TRITON_AMD flags

bench fwd and bwd

fix sequence_parallel
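
Several commits above toggle USE_EXP2 and juggle LN2 factors; the underlying identity is that the hardware-friendly exp2 can replace exp by folding log2(e) into the scale, while the stored log-sum-exp stays in natural-log units. A sketch of the rewrite (standard numerics, not anything specific to this kernel):

```latex
e^{x} = 2^{\,x \log_2 e}
\;\;\Longrightarrow\;\;
p_{ij} = 2^{\,(s_{ij} - m_i)\log_2 e},
\qquad
L_i = m_i + \ln \ell_i = m_i + (\log_2 \ell_i)\,\ln 2
```

with m_i the running row maximum, ℓ_i the running row sum of exponentials, and L_i the softmax_lse the backward pass consumes. Getting the ln-versus-log2 conversion wrong produces exactly the softmax_lse mismatches chased in the commits above.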