Bug: Flash Attention performs worse under ROCM #10439
Comments
It's a known issue, caused by the HIP port of the CUDA FlashAttention kernel for large batch sizes having extremely poor performance (so the kernel optimized for small batch sizes is used instead). With the current code this issue cannot be fixed.
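To make that dispatch concrete, here is a minimal hypothetical sketch of the kind of selection logic being described: a vector kernel tuned for small batches, a tensor-core kernel for large batches, and a HIP path that always falls back to the small-batch kernel because the translated large-batch kernel is so slow. The function and kernel names and the threshold of 8 are assumptions for illustration, not llama.cpp internals.

```cpp
// Hypothetical illustration of the kernel selection described above.
// Names and the batch-size threshold are assumptions, not llama.cpp code.
#include <cstdio>
#include <initializer_list>

enum class fa_kernel {
    vec_small_batch,   // scalar/vector kernel, fastest for a handful of queries
    mma_large_batch    // tensor-core (WMMA/MMA) kernel, fastest for big batches on NVIDIA
};

fa_kernel select_flash_attn_kernel(int n_queries, bool is_hip_port) {
    if (is_hip_port) {
        // The HIP translation of the large-batch kernel is slower than the
        // small-batch kernel even at large batch sizes, so it is never chosen.
        return fa_kernel::vec_small_batch;
    }
    return n_queries <= 8 ? fa_kernel::vec_small_batch : fa_kernel::mma_large_batch;
}

int main() {
    const bool hip = true; // pretend we are on the ROCm/HIP backend
    for (int n : {1, 8, 16, 512}) {
        printf("n_queries=%3d -> %s\n", n,
               select_flash_attn_kernel(n, hip) == fa_kernel::vec_small_batch
                   ? "vec (small-batch) kernel"
                   : "mma (large-batch) kernel");
    }
    return 0;
}
```

With a heuristic like this, the HIP backend ends up running the small-batch kernel even for the batch-16 case benchmarked below, which is why the gap grows with batch size.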
Are there any plans to rewrite that code so it's optimized for ROCm instead of being a CUDA port?
No. There currently isn't a llama.cpp/GGML dev working specifically on AMD performance or even support. I am writing a lot of CUDA code, but the extent of the effort I am willing to invest is making sure that the HIP port doesn't break and determining the comparatively best code paths for AMD.
Why does the performance also regress when FA is enabled without batching? That's a bit weird given that it's optimized for small batch sizes. Large batch sizes haven't really been that important in the past: hobbyists usually don't use parallelism, since they have single-user use cases most of the time. But with the introduction of speculative decoding, which will hopefully land in the llama.cpp server in the future as well, performance at larger batch sizes will become important even for single-user use cases. What will it take for AMD optimizations to be considered? Will it take someone joining the project to develop and maintain those optimizations? Or might a shift towards more AMD optimization also happen when (if?) AMD becomes more popular among end users?
The code is optimized for small batch sizes and NVIDIA GPUs; if you use HIP to translate it for AMD, you still pay a performance penalty vs. using an external BLAS library optimized for AMD.
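For context on what "use HIP to translate it" means in practice: HIP ports typically compile the CUDA sources as-is and remap the API symbols, so the NVIDIA-tuned tile sizes and launch parameters carry over unchanged. A rough sketch of that pattern follows; the guard macro and the particular defines are illustrative assumptions, not a copy of llama.cpp's actual HIP compatibility header.

```cpp
// Illustrative only: the usual "compile CUDA source for AMD" trick.
// API calls are renamed to HIP equivalents, but the kernels themselves,
// with tiling and occupancy choices tuned for NVIDIA SMs, are unchanged,
// which is where the translated-code performance penalty comes from.
#if defined(USE_HIP)                       // stand-in guard macro for this sketch
#include <hip/hip_runtime.h>
#define cudaError_t            hipError_t
#define cudaSuccess            hipSuccess
#define cudaMalloc             hipMalloc
#define cudaFree               hipFree
#define cudaMemcpyAsync        hipMemcpyAsync
#define cudaStream_t           hipStream_t
#define cudaStreamSynchronize  hipStreamSynchronize
#else
#include <cuda_runtime.h>
#endif
```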
My personal goal with llama.cpp is to reduce the cost at which the largest models can be run at reasonable speeds. As of right now I think there simply isn't any AMD hardware that would be worth buying over second-hand NVIDIA hardware (RTX 3090/P40). I have seen some second-hand AMD GPUs such as the MI60 at very competitive prices on eBay, but unfortunately the supply seems to be extremely limited. If AMD were to release a consumer GPU with a high VRAM capacity at a sufficiently low price, I would start optimizing performance for that GPU (even though the AMD dev tools are worse or nonexistent). If a new dev were to join the project with an interest in improving AMD performance, I would be happy to assist them.
Understandable, thanks for the detailed explanation! Hoping to see other devs join the project to optimize AMD performance then :) By the way, do you think a (second-hand) 7900 XTX could potentially be cost-competitive with second-hand 3090 and 4090 GPUs? The memory bandwidth is very similar between those GPUs.
It has become better compared to 1-2 years ago but at least in my region (Germany) there currently don't really seem to be good second-hand offers for RX 7900 XTX cards. |
You may try my forked branch that enables rocWMMA on top of the current CUDA WMMA flash attention implementation; I'm actively rebasing it onto the latest master branch for my own usage: https://github.com/hjc4869/llama.cpp
From my own testing it improves performance by quite a bit on RDNA3 at higher batch sizes, though it's still not optimal compared to equivalent NVIDIA GPUs.

Flash attention (master branch):
./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240
main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32
Flash attention (WMMA patched):
make GGML_HIPBLAS=1 GGML_CUDA_FA_ALL_QUANTS=1 AMDGPU_TARGETS=gfx1100,gfx1101 -j64
./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240
main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32
Flash attention off:
./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -npl 1,8,16 -npp 512 -ntg 128 -c 10240
main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32
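For anyone curious what "enables rocWMMA on top of the CUDA WMMA implementation" means in practice: rocWMMA exposes a fragment API that closely mirrors nvcuda::wmma, which is what makes compiling the existing WMMA flash-attention kernel for RDNA3 feasible. Below is a minimal standalone sketch of a single 16x16x16 half-precision tile multiply in that API; it is not code from the fork, just an illustration of the API surface.

```cpp
// Minimal rocWMMA sketch: one wavefront computes a 16x16 tile of C = A * B.
// This only illustrates the fragment API; it is not the flash-attention kernel.
#include <hip/hip_runtime.h>
#include <rocwmma/rocwmma.hpp>

using rocwmma::float16_t;
using rocwmma::float32_t;

constexpr int WMMA_M = 16, WMMA_N = 16, WMMA_K = 16;

__global__ void tile_gemm_16x16x16(const float16_t * A, const float16_t * B, float32_t * C) {
    // Fragment types mirror nvcuda::wmma::fragment, so a CUDA WMMA kernel maps over almost 1:1.
    rocwmma::fragment<rocwmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, float16_t, rocwmma::row_major> a_frag;
    rocwmma::fragment<rocwmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, float16_t, rocwmma::col_major> b_frag;
    rocwmma::fragment<rocwmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float32_t> acc_frag;

    rocwmma::fill_fragment(acc_frag, 0.0f);
    rocwmma::load_matrix_sync(a_frag, A, WMMA_K);            // row-major A, leading dim = K
    rocwmma::load_matrix_sync(b_frag, B, WMMA_K);            // col-major B, leading dim = K
    rocwmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);   // acc += A * B
    rocwmma::store_matrix_sync(C, acc_frag, WMMA_N, rocwmma::mem_row_major);
}
```

Compiled with hipcc for gfx1100/gfx1101 (the AMDGPU_TARGETS in the make line above), this is the kind of tensor-core path the patched WMMA FA kernel can take instead of the plain vector kernel.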
Wow! I am seeing very good results here. Some observations:
A question: I saw that your branch also introduced the option to compile with GGML_CUDA_FA_ALL_QUANTS. Is that related to this issue, and what does it do?

Master branch:
Single bench
Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no
build: a5e4759 (4150)
Batched bench FA off
main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12
ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
Batched bench FA on
main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

Your branch:
Single bench
Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no
build: 8739ed4 (4154)
Batched bench FA off
main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12
ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
Batched bench FA on
main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12
GGML_CUDA_FA_ALL_QUANTS is not related to the issue. It compiles FA kernels for all the KV cache quantization combinations, for example if you want -ctk q4_0 in combination with -ctv q8_0 (different quantization types), or the rarely used ones like q4_1, q5_0, q5_1.
Thanks for your explanation. Do you have any intention of trying to upstream these FA changes?
Nope. Actually this has been done before, and @JohannesGaessler had some explanation here: #7011. Although the current patch does improve performance to some degree, it is still not optimized for AMD hardware. I think we should try to push AMD to officially take responsibility for maintaining ROCm/HIP support in llama.cpp, just like MUSA/CANN, rather than relying on the community to port code optimized for NVIDIA. But these days they seem to be primarily focused on contributing to the vLLM project rather than to things like llama.cpp that are easy to use on client devices/Radeon cards, unfortunately.
Didn't the comment in that PR mention they would be open to merging improvements, as long as someone committed to maintaining that code? Not sure if that would be an option for you. Is there some way we can reach out to AMD and convince them to commit to performance improvements and to maintaining them for llama.cpp?
I'm quite new to this area and not a GPU programming / machine learning expert. I do have some contacts at AMD, but they're quite busy these days, and diverting resources to a new project would probably need attention from executives, which of course isn't something I can arrange. AMD themselves are actively using llama.cpp in their marketing materials, and Xilinx also has an internal forked branch for their AIE NPUs. But getting actual engineering resources to work on this project wouldn't be so easy.
That's unfortunate to hear. I really think the 7900 XTX could be competitive on both cost and performance; it just needs the support to get there. The situation has already improved massively compared to 1-2 years ago though, so perhaps in the future we'll get there.
A new 7900 XTX is pretty much the same price as (or reasonably close to) a used 3090, so doesn't that mean AMD cards are competitive? The used NVIDIA market is going to dry up at some point, and NVIDIA doesn't care to make affordable cards with enough VRAM for LLMs.
The problem is that the 7900 XTX doesn't have the performance of an RTX 3090 in LLMs, either on paper or in real-world testing. I got some review data on LLM performance on NVIDIA GPUs earlier this month, and the position of Navi31 is quite clear.
A new 7900 XTX is slightly more expensive and performs slightly worse than a used 3090. With performance optimizations it would be competitive, but right now the 3090 has the advantage. It's a bit of a chicken-and-egg problem, to be honest.
What happened?
Turning on flash attention degrades performance under ROCm (at least it does with a 7900 XTX). Using batched bench, the degradation is quite minor at a batch size of 1:
prompt processing: 461 -> 434 t/s
token generation: 24.26 -> 23.84 t/s
However, when running multiple batches of requests at the same time, the effect is MUCH more pronounced. At a batch size of 16 the difference is massive:
prompt processing: 678 -> 375 t/s
token generation: 169.65 -> 86.87 t/s
Flash Attention is needed to be able to use quantization for the KV-cache, but the performance hit is drastic. Can this be fixed?
Name and Version
build: 4123 (2eb76b2) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
What operating system are you seeing the problem on?
Linux
Relevant log output
No response