forked from Dao-AILab/flash-attention
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up for Upstream #81
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
micmelesse
added a commit
that referenced
this pull request
Oct 14, 2024
* Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride
micmelesse
added a commit
that referenced
this pull request
Oct 28, 2024
Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes
micmelesse
added a commit
that referenced
this pull request
Oct 28, 2024
Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup
micmelesse
added a commit
that referenced
this pull request
Oct 28, 2024
Enable Fwd and Backward Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup FLASH_ATTENTION_TRITON_AMD flags bench fwd and bwd fix sequence_parallel
rocking5566
pushed a commit
that referenced
this pull request
Dec 17, 2024
* Enable Fwd and Backward Enable Fwd and Backward Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup FLASH_ATTENTION_TRITON_AMD flags bench fwd and bwd fix sequence_parallel * clean up * Enable sequence_parallel in bwd (#89) * sequence_parallel working on bwd_impl test * fix qkv error * save * save * save * bwd 3 times faster * clean up * fix varlen bug * use copy back dict * fix qkvpacked bug * reduce bench sizes * print copy back * clean more * Autotune off by default * update Triton commit readme (#92)
micmelesse
added a commit
that referenced
this pull request
Feb 20, 2025
Enable Fwd and Backward Enable Fwd and Backward Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup FLASH_ATTENTION_TRITON_AMD flags bench fwd and bwd fix sequence_parallel Enable sequence_parallel in bwd (#89) * sequence_parallel working on bwd_impl test * fix qkv error * save * save * save * bwd 3 times faster * clean up * fix varlen bug * use copy back dict * fix qkvpacked bug * reduce bench sizes * print copy back Autotune off by default (#90) * Autotune off by default * rework tests Update Triton Version (#91) * ignore ck code * update triton update Triton commit readme (#92) Fix README (#96) * Update README.md * fix readme Enable MQA/GQA in backward (#100) * simple failing test * ref is working * fix bug * save * find failing case * fowrad varlen mqa/gqa works * add mqa configs to bwd test * varlen bwd ref fixed * save failing case * GQA flag * ones passes * go back to values * save * bhsd working with mqa * remove repo * test layouts * clean up * test back to normal * clean up more * use zeros_like * zero out Added Support for Rotary Positional Embeddings (#99) * feat: added rotary support in kvcache * confirmed non-fused rotary passes all tests add RDNA CI (#105) * Add RDNA CI This is a combination of 4 commits. try navi try matrix small change try minimal change * limit navi tests * stop casting to fp32 which leads to oom on navi * enable all causal * revert all causal * skip compiler bug on navi Dropout (#101) * Alex's work This is a combination of 11 commits. save fix: dropout=0.0 woorks feat: dropout restrictions removed. failing tests test: reduced tests to simple cases test: failure is due to query + key padding mask NOT varlen itself feat: varlen dropout fwd passes fix: varlen bwd dropout works! test: discovered bwd error for non-dropout cases for large seqlen save save use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes * Almost Everything works. This is a combination of 16 commits. Work so far This is a combination of 63 commits. pick test case save philox offsets into metadata pass offset to ref common dropout mask simple droput out mask start dropout ref. work on returning SD_Mask next with negative numbers refernce is working dropout bwd ref faling case transfer rng_state properly save changes one dropout mask function save save minizmize diff save use torch.where in backward save save save dk works! passes reference is working. TODO" attn_ref is broken varlen ref working attn failing case with ones. attn_ref matches. fails with randn. we are seeing failure with large sizes from dv. save skip attn matrices compare the masks and find failing case rm cdiv_fn put dropout and alibi in common save compare masks save save pytorch ref is using tiles save save tl_rand_ref cache ref dropout mask new generate_dropout_mask_ref using tiling issolate failing varlen case simple dropout loop on k print rng_outputs save fwd kernel works save dv passed close to dk simple ref save seperate droped and scaled in ref and triton kernel ref changes working delta with dp find failing dv failures find failing case due to delta save delta from dp working bwd impl green enable test fwd save save delete kernels save probably mask application mismatch dump forward dropout pass dropout mask tensor to bwd_core different dropout fraction in fwd and bwd mismatch found on columns greater than 64 fix dropout bug. philox was not offset run full suite stop debug and approximate delta fix drop_mask non issue skip attn check clean up common bad varlen config fix varlen bug save * fix datatype mismatch * clean up * use pytorch dropout * It works on MI300. * remove _bwd_preprocess_use_p * fix torch interface bug --------- Co-authored-by: Alex Kranias <[email protected]> fp8 forward (#116) * disable navi * start test * test fp16 against fp8 * save scaling code so far * global scaling * add per_head_scaling * dump qk * save dumping q, k and qk to fp32 tensor * fix pointer bug * save reproducer * dump p and acc * fp8 working with my debug input * save * change api for dequant * pass descale_p * clean up * most working * save * save * varlen half way * some varlen examples work * improve varlen debug input * varlen mostly working * push working cases * fix ref bug * fix backward bug * fix varlen backward bug * use descale to set fp8 * check arch fp8 support * cache arch * try again * skip bad config on MI200 * skip decode nan config on MI200 * fix mistake * skip more * run full suit * Update amd_tests.yml * address comments * navi ci is broken * raise error tolerance to 2.5e-1 * target MI300 directly * show gfx * try again * don't fail matrix if one path fails * try upstream triton * just get MI300 working * Fix install bug This is a combination of 5 commits. try this use --no-build-isolation put route at .python run full suite remove triton * run ref on cpu * move ref test to navi machines * pin triton * add bench deps Update readme Minor fixes (#107) * Clean up This is a combination of 4 commits. update base image disable navi for now all causal seems to work on MI300 skip MI200 causal bugs * remove MI200 skips * just run on prs or manually * add navi back * try again * update readme * mark flakey test * ref bug Performant backward Triton implementation with separated dkdv and dq kernels (#122) * added the split file * overhauled split file, need to add new kernels * copied triton fa over for reference * added comments * preprocess and dkdv done * fixed dkdv, added dq * fixed assumption on q, kv length different, run but incorrect * added standalone test for split bwd kernel * minor change on the ptr arith * separated the dkdv and dq kernels * GQA works now, onto seqlen q != k * dk,dq working, dv still failing * fixed the masking and num_step calc, now q==k works * added debug print with interpreter, might not work entirely w/o next commit * fixed all issues with q != k * fixed varlen issue * fixup on debug print * fixed dropout, esp w/ varlen * added USE_EXP2 toggle * added noncausal kernel * updated internal test for noncausal and use_exp2 * formatting * fixed dropout from seed bug * added envvar USE_SPLIT to toggle btw bwd kernels * fixed the qkv pack issue and removed hack * added the split kernel into interface_fa.py * change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default * removed redundant file * fixed missing import in test * fixed import in interface_fa.py * revert changes in flash_attn_interface.py * updated strides to adapt to various tensor init shape * fixed issue that dqkv not zero'd * disabled the AMD local test Quick Fixes (#124) * fix fp8 bug * fix type bug * forgot nones * docker file reenable gfx1100 ci (#121) * reenable * randomly sample * clean up ci * add pytest-randomly * try again update triton commit (#128) * update triton commit * disable navi update base docker image (#129) fp8 BWD after figuring out varlen problem This is a combination of 21 commits. fp8 BWD Enable BWD fp8 with split kernel Enable BWD fp8 with per block scale factors for p and ds This is a combination of 9 commits. Enable BWD fp8 This is a combination of 12 commits. add backward test case save clean up disable ci lse is good dv matches reduce diff use do fp8 for dv kinda working group size is a constexpr clean up a bit everything except mqa/gqa works skip mqa cases 20 cases have nan on dropout save what you have disable tests failing enable tests per block descale_p and descale_ds use max(abs(()) clean up tests a bit more fix bug disable ci for now pass variables add flags add alternate path. Still need to load descale factors dv working dk works save add type info for backward fix DEBUG flag bug fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides. pass descale strides test causal fix causal compiler assert. min head should be 32 remove descale_p save explict name as causal isolate bad case just run fp8 tests bench with autotune min changes cast_fp8 helper cast_varlen_to_fp8 save minor highlight failing configs increase test cases mark failing recategorize misc tests group failing gqa configs add more tests add vis code min ci changes dump folder single image per tensors add tensor comparison gen varlen tensor vis varlen tensors varlen diff nice varlen vis vis function show seqlen in varlen add vis_tensors function simplify add color bars rm vis from test set canvas size. descale values are optional add ck tests add flag to build ck rm ck test assert requires grad ensure q, k, and v require gradients split vis rm interp, 8k and 300 dpi slice per page disable ci for now add more vis code tensor per image is better for vis_close, don't vis if no error. also vis all failing varlen tests varlen failures due to different seqlens rm vis code
micmelesse
added a commit
that referenced
this pull request
Feb 20, 2025
Enable Fwd and Backward Enable Fwd and Backward Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup FLASH_ATTENTION_TRITON_AMD flags bench fwd and bwd fix sequence_parallel Enable sequence_parallel in bwd (#89) * sequence_parallel working on bwd_impl test * fix qkv error * save * save * save * bwd 3 times faster * clean up * fix varlen bug * use copy back dict * fix qkvpacked bug * reduce bench sizes * print copy back Autotune off by default (#90) * Autotune off by default * rework tests Update Triton Version (#91) * ignore ck code * update triton update Triton commit readme (#92) Fix README (#96) * Update README.md * fix readme Enable MQA/GQA in backward (#100) * simple failing test * ref is working * fix bug * save * find failing case * fowrad varlen mqa/gqa works * add mqa configs to bwd test * varlen bwd ref fixed * save failing case * GQA flag * ones passes * go back to values * save * bhsd working with mqa * remove repo * test layouts * clean up * test back to normal * clean up more * use zeros_like * zero out Added Support for Rotary Positional Embeddings (#99) * feat: added rotary support in kvcache * confirmed non-fused rotary passes all tests add RDNA CI (#105) * Add RDNA CI This is a combination of 4 commits. try navi try matrix small change try minimal change * limit navi tests * stop casting to fp32 which leads to oom on navi * enable all causal * revert all causal * skip compiler bug on navi Dropout (#101) * Alex's work This is a combination of 11 commits. save fix: dropout=0.0 woorks feat: dropout restrictions removed. failing tests test: reduced tests to simple cases test: failure is due to query + key padding mask NOT varlen itself feat: varlen dropout fwd passes fix: varlen bwd dropout works! test: discovered bwd error for non-dropout cases for large seqlen save save use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes * Almost Everything works. This is a combination of 16 commits. Work so far This is a combination of 63 commits. pick test case save philox offsets into metadata pass offset to ref common dropout mask simple droput out mask start dropout ref. work on returning SD_Mask next with negative numbers refernce is working dropout bwd ref faling case transfer rng_state properly save changes one dropout mask function save save minizmize diff save use torch.where in backward save save save dk works! passes reference is working. TODO" attn_ref is broken varlen ref working attn failing case with ones. attn_ref matches. fails with randn. we are seeing failure with large sizes from dv. save skip attn matrices compare the masks and find failing case rm cdiv_fn put dropout and alibi in common save compare masks save save pytorch ref is using tiles save save tl_rand_ref cache ref dropout mask new generate_dropout_mask_ref using tiling issolate failing varlen case simple dropout loop on k print rng_outputs save fwd kernel works save dv passed close to dk simple ref save seperate droped and scaled in ref and triton kernel ref changes working delta with dp find failing dv failures find failing case due to delta save delta from dp working bwd impl green enable test fwd save save delete kernels save probably mask application mismatch dump forward dropout pass dropout mask tensor to bwd_core different dropout fraction in fwd and bwd mismatch found on columns greater than 64 fix dropout bug. philox was not offset run full suite stop debug and approximate delta fix drop_mask non issue skip attn check clean up common bad varlen config fix varlen bug save * fix datatype mismatch * clean up * use pytorch dropout * It works on MI300. * remove _bwd_preprocess_use_p * fix torch interface bug --------- Co-authored-by: Alex Kranias <[email protected]> fp8 forward (#116) * disable navi * start test * test fp16 against fp8 * save scaling code so far * global scaling * add per_head_scaling * dump qk * save dumping q, k and qk to fp32 tensor * fix pointer bug * save reproducer * dump p and acc * fp8 working with my debug input * save * change api for dequant * pass descale_p * clean up * most working * save * save * varlen half way * some varlen examples work * improve varlen debug input * varlen mostly working * push working cases * fix ref bug * fix backward bug * fix varlen backward bug * use descale to set fp8 * check arch fp8 support * cache arch * try again * skip bad config on MI200 * skip decode nan config on MI200 * fix mistake * skip more * run full suit * Update amd_tests.yml * address comments * navi ci is broken * raise error tolerance to 2.5e-1 * target MI300 directly * show gfx * try again * don't fail matrix if one path fails * try upstream triton * just get MI300 working * Fix install bug This is a combination of 5 commits. try this use --no-build-isolation put route at .python run full suite remove triton * run ref on cpu * move ref test to navi machines * pin triton * add bench deps Update readme Minor fixes (#107) * Clean up This is a combination of 4 commits. update base image disable navi for now all causal seems to work on MI300 skip MI200 causal bugs * remove MI200 skips * just run on prs or manually * add navi back * try again * update readme * mark flakey test * ref bug Performant backward Triton implementation with separated dkdv and dq kernels (#122) * added the split file * overhauled split file, need to add new kernels * copied triton fa over for reference * added comments * preprocess and dkdv done * fixed dkdv, added dq * fixed assumption on q, kv length different, run but incorrect * added standalone test for split bwd kernel * minor change on the ptr arith * separated the dkdv and dq kernels * GQA works now, onto seqlen q != k * dk,dq working, dv still failing * fixed the masking and num_step calc, now q==k works * added debug print with interpreter, might not work entirely w/o next commit * fixed all issues with q != k * fixed varlen issue * fixup on debug print * fixed dropout, esp w/ varlen * added USE_EXP2 toggle * added noncausal kernel * updated internal test for noncausal and use_exp2 * formatting * fixed dropout from seed bug * added envvar USE_SPLIT to toggle btw bwd kernels * fixed the qkv pack issue and removed hack * added the split kernel into interface_fa.py * change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default * removed redundant file * fixed missing import in test * fixed import in interface_fa.py * revert changes in flash_attn_interface.py * updated strides to adapt to various tensor init shape * fixed issue that dqkv not zero'd * disabled the AMD local test Quick Fixes (#124) * fix fp8 bug * fix type bug * forgot nones * docker file reenable gfx1100 ci (#121) * reenable * randomly sample * clean up ci * add pytest-randomly * try again update triton commit (#128) * update triton commit * disable navi update base docker image (#129) fp8 BWD after figuring out varlen problem This is a combination of 21 commits. fp8 BWD Enable BWD fp8 with split kernel Enable BWD fp8 with per block scale factors for p and ds This is a combination of 9 commits. Enable BWD fp8 This is a combination of 12 commits. add backward test case save clean up disable ci lse is good dv matches reduce diff use do fp8 for dv kinda working group size is a constexpr clean up a bit everything except mqa/gqa works skip mqa cases 20 cases have nan on dropout save what you have disable tests failing enable tests per block descale_p and descale_ds use max(abs(()) clean up tests a bit more fix bug disable ci for now pass variables add flags add alternate path. Still need to load descale factors dv working dk works save add type info for backward fix DEBUG flag bug fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides. pass descale strides test causal fix causal compiler assert. min head should be 32 remove descale_p save explict name as causal isolate bad case just run fp8 tests bench with autotune min changes cast_fp8 helper cast_varlen_to_fp8 save minor highlight failing configs increase test cases mark failing recategorize misc tests group failing gqa configs add more tests add vis code min ci changes dump folder single image per tensors add tensor comparison gen varlen tensor vis varlen tensors varlen diff nice varlen vis vis function show seqlen in varlen add vis_tensors function simplify add color bars rm vis from test set canvas size. descale values are optional add ck tests add flag to build ck rm ck test assert requires grad ensure q, k, and v require gradients split vis rm interp, 8k and 300 dpi slice per page disable ci for now add more vis code tensor per image is better for vis_close, don't vis if no error. also vis all failing varlen tests varlen failures due to different seqlens rm vis code
micmelesse
added a commit
that referenced
this pull request
Feb 20, 2025
Enable Fwd and Backward Enable Fwd and Backward Enable fwd and varlen_fwd on AMD (#63) * flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2 * Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works * keep interface and kernel seperate * clean up enable flash_attn_with_kvcache (#68) * Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save * fix causal. use cache_seqlens * clean and test what works * some configs work on new_kv but fails on 1,8 * cache overwrite correct * new_kv works more or less * test local * work on paged kv attention * prefill paged attention * fix has_batch_idx and skip local and rotatary emb * save * save * save * save * handle new_kv when paged kv cache * all except has_batch_idx works * major options are green * test all * add tests * save * clean up * minor clean up * simplest config * save debug true * save * refactor slightly * save work * need key masking * force hip * use is_hip * save * fix cache_seq_len issue * work on new_kv * pass new_kv data * save * benchmark fwd only * disable debug * pandas pdf * save * set methods * record number of heads * use configs * flexiable dim, n-heads, headofdim * better benchmarking * basic inplace update working * works upto 64 * new_kv supported! * test case for has_batch_idx * has_batch_idx works! * save * save * save * save ref * fix mqa and gqa by duplicating * GQA and MQA working by kernel modifications * fix new_kv with gqa * cache index * deal with nans on fwd_splitk * save * causal working on basic case * causal works! * alibi works! * clean up * clean prefill changes * remove bwd stuff * limit decode test to test_op_fwd * add ref * use bfloat Fixes after rebase Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug enable packed layouts and all configs (#72) Clean up for Upstream (#81) * Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix * use is_hip() * clean up more * skip odd d only * fix bug * skip randomly * use Flag * update readme * remove quantization * remove bwd * minor * print * remove verbose print * qunatize zero's out the d stride Enable Vanilla Bwd and Refactor (#86) * Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes REBASE: fix interface changes in rebase rename test to test_flash_attn_triton_amd REBASE: fix unpad diffs minor clean up in setup FLASH_ATTENTION_TRITON_AMD flags bench fwd and bwd fix sequence_parallel
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Hi, this is a pr to add a Triton backend to Flash Attention on ROCm. We hope that this pr will be the first in a series of prs to that end. Triton has had support for ROCm for a while now and a Flash Attention Triton backend will allows us to support Flash Attention on both our MI and Navi Machines.
In this pr, we enable major parts of
fwd
,varlen_fwd
andfwd_kvcache
. However there are some features missing such as Dropout, Sliding window, Rotary Embedding and Pagged Attention. There are also a few miscellaneous bugs. These will all be addressed in subsequent prs. The next pr we plan to file will be support forbwd
andvarlen_vwd
, if we should reprioritize, please let us know.We have tested this pr here on an MI200 machine. When the testing the Triton Backend for ROCm, we skip testing the backward pass, configs with unsupported features and a portion of headsizes (d) randomly. The later is to keep the test times reasonable. The latest results, we have are
=== 64 failed, 30387 passed, 478321 skipped, 1 warning in 3110.86s (0:51:50) ===
. There is clearly more work to be done but we hope that this will make a good start.Please let us know what we can do on our end to help with this process.
Finally this pr includes work from multiple people besides myself, especially thanks to @vgokhale, @scxiao and @jlgreathouse.