Bug: Flash Attention performs worse under ROCM #10439

Open · Mushoz opened this issue Nov 20, 2024 · 18 comments
Labels: bug-unconfirmed, medium severity (used to report medium severity bugs in llama.cpp, e.g. malfunctioning features that are still usable)

Comments

Mushoz commented Nov 20, 2024

What happened?

Turning on flash attention degrades performance under ROCm (at least it does with a 7900 XTX). Using batched bench, the degradation is quite minor at a batch size of 1:

prompt processing: 461 -> 434 t/s
token generation: 24.26 -> 23.84 t/s

However, when running multiple requests in parallel, the effect is MUCH more pronounced. At a batch size of 16 the difference is massive:

prompt processing: 678 -> 375 t/s
token generation: 169.65 -> 86.87 t/s

Flash Attention is needed to be able to use quantization for the KV-cache, but the performance hit is drastic. Can this be fixed?
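
A reproduction along these lines should show the gap (a minimal sketch; the model path is a placeholder and the exact settings may differ from my runs):

```bash
# batched bench with flash attention enabled, batch sizes 1 and 16
./llama-batched-bench -m /path/to/model.gguf -ngl 99 -fa -npl 1,16 -npp 384 -ntg 128 -c 16384

# same run with flash attention disabled, for comparison
./llama-batched-bench -m /path/to/model.gguf -ngl 99 -npl 1,16 -npp 384 -ntg 128 -c 16384
```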

Name and Version

build: 4123 (2eb76b2) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

Mushoz added the bug-unconfirmed and medium severity labels on Nov 20, 2024
JohannesGaessler (Collaborator) commented

It's a known issue, caused by the HIP port of the CUDA FlashAttention kernel for large batch sizes having extremely poor performance (so the kernel optimized for small batch sizes is used instead). With the current code this issue cannot be fixed.

Mushoz commented Nov 20, 2024

Are there any plans to rewrite that code so it is optimized for ROCm, rather than being a CUDA port?

JohannesGaessler (Collaborator) commented

No. There currently isn't a llama.cpp/GGML dev working specifically on AMD performance or even support. I am writing a lot of CUDA code, but the extent of the effort I am willing to invest is making sure that the HIP port doesn't break and determining the comparatively best code paths for AMD.

Mushoz commented Nov 21, 2024

Why does performance also regress when FA is enabled without batching? That's a bit odd given that the kernel is optimized for small batch sizes.

Large batch sizes haven't really been that important in the past. Hobbyists usually do not use parallelism, as they have single-user use cases most of the time. But with the introduction of speculative decoding, which will hopefully land in llama-server in the future as well, performance at larger batch sizes will become important even for single-user use cases.

What will it take for AMD optimizations to be considered? Will it take someone joining the project to develop and maintain these optimizations? Or might a shift towards more AMD optimization also happen when (if?) AMD becomes more popular among end users?

JohannesGaessler (Collaborator) commented

> Why does performance also regress when FA is enabled without batching? That's a bit odd given that the kernel is optimized for small batch sizes.

The code is optimized for small batch sizes and NVIDIA GPUs; if you use HIP to translate it for AMD, you still pay a performance penalty vs. using an external BLAS library optimized for AMD.

> What will it take for AMD optimizations to be considered? Will it take someone joining the project to develop and maintain these optimizations? Or might a shift towards more AMD optimization also happen when (if?) AMD becomes more popular among end users?

My personal goal with llama.cpp is to reduce the cost at which the largest models can be run at reasonable speeds. As of right now I think there simply isn't any AMD hardware that would be worth buying over second-hand NVIDIA hardware (RTX 3090/P40). I have seen some second-hand AMD GPUs such as the MI60 at very competitive prices on eBay, but unfortunately the supply seems to be extremely limited. If AMD were to release a consumer GPU with a high VRAM capacity at a sufficiently low price, I would start optimizing performance for that GPU (even though the AMD dev tools are worse or nonexistent).

If a new dev were to join the project with an interest in improving AMD performance I would be happy to assist them.

Mushoz commented Nov 21, 2024

Understandable, thanks for the detailed explanation! Hoping to see other devs join the project to optimize AMD performance then :)

By the way, do you think a (second-hand) 7900 XTX could potentially be cost-competitive with second-hand 3090 and 4090 GPUs? The memory bandwidth is very similar between those GPUs.

JohannesGaessler (Collaborator) commented

It has become better compared to 1-2 years ago but at least in my region (Germany) there currently don't really seem to be good second-hand offers for RX 7900 XTX cards.

hjc4869 commented Nov 22, 2024

You may try my fork, which enables rocWMMA on top of the current CUDA WMMA flash attention implementation; I'm actively rebasing it onto the latest master branch for my own usage: https://github.com/hjc4869/llama.cpp

From my own testing it improves performance quite a bit on RDNA3 at higher batch sizes, though it is still not optimal compared to equivalent NVIDIA GPUs.

Flash attention (master branch)

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 512 | 128 | 1 | 640 | 1.762 | 290.58 | 9.131 | 14.02 | 10.893 | 58.75 |
| 512 | 128 | 8 | 5120 | 24.958 | 164.12 | 28.862 | 35.48 | 53.820 | 95.13 |
| 512 | 128 | 16 | 10240 | 76.425 | 107.19 | 85.347 | 24.00 | 161.771 | 63.30 |

Flash attention (WMMA patched)

make GGML_HIPBLAS=1 GGML_CUDA_FA_ALL_QUANTS=1 AMDGPU_TARGETS=gfx1100,gfx1101 -j64

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -fa -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 512 | 128 | 1 | 640 | 1.601 | 319.81 | 9.180 | 13.94 | 10.781 | 59.36 |
| 512 | 128 | 8 | 5120 | 15.654 | 261.65 | 29.136 | 35.15 | 44.790 | 114.31 |
| 512 | 128 | 16 | 10240 | 37.842 | 216.48 | 38.110 | 53.74 | 75.952 | 134.82 |

Flash attention off

./llama-batched-bench -ngl 999 -m ~/models/qwen2.5-72b-iq4.gguf -npl 1,8,16 -npp 512 -ntg 128 -c 10240

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 999, n_threads = 32, n_threads_batch = 32

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 512 | 128 | 1 | 640 | 1.584 | 323.17 | 8.691 | 14.73 | 10.276 | 62.28 |
| 512 | 128 | 8 | 5120 | 15.569 | 263.09 | 26.352 | 38.86 | 41.921 | 122.13 |
| 512 | 128 | 16 | 10240 | 37.198 | 220.23 | 50.002 | 40.96 | 87.200 | 117.43 |

Mushoz commented Nov 22, 2024

Wow! I am seeing very good results here. Some observations:

  1. With FA turned off, performance between your branch and master is identical. So that's good: it did not introduce any regressions in the FA-off case.
  2. With the non-batched bench, the prompt-processing regression from turning on FA is gone. Turning on FA gives about the same tokens/sec during PP as having it turned off, whereas on the master branch there is a big regression.
  3. With the non-batched bench, the token-generation regression from turning on FA is still there, but it has not gotten worse. So identical performance to the master branch there.
  4. For the batched bench, at batch sizes of 1, 2, 4 and 8, performance with FA turned on is identical to the master branch, including the drop-off when going from batch size 4 to batch size 8. So there's probably still some performance to be gained versus the non-FA case here, but at least this branch did not regress compared to master.
  5. For batch sizes 16 and 32, turning on FA massively boosts performance, whereas on master turning on FA would actually REDUCE performance. Very good improvement there!
  6. All in all, I did not experience a single downside compared to the master branch. Very good job!

A question: I saw that your branch also introduced the option to compile with GGML_CUDA_FA_ALL_QUANTS. What does this do? I did not enable this, but is it worth enabling?

Master branch:

Single Bench

Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no

| model | size | params | backend | ngl | fa | test | t/s |
|-------|------|--------|---------|-----|----|------|-----|
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | pp512 | 744.82 ± 2.17 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | tg128 | 27.19 ± 0.06 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | pp512 | 639.80 ± 0.63 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | tg128 | 25.64 ± 0.01 |

build: a5e4759 (4150)

Batched bench FA off

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 384 | 128 | 1 | 512 | 0.550 | 697.74 | 4.969 | 25.76 | 5.519 | 92.77 |
| 384 | 128 | 2 | 1024 | 1.092 | 703.51 | 6.267 | 40.85 | 7.359 | 139.16 |
| 384 | 128 | 4 | 2048 | 2.200 | 698.23 | 9.225 | 55.50 | 11.425 | 179.26 |
| 384 | 128 | 8 | 4096 | 4.813 | 638.29 | 16.997 | 60.25 | 21.810 | 187.81 |
| 384 | 128 | 16 | 8192 | 11.088 | 554.12 | 17.008 | 120.41 | 28.096 | 291.57 |

ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
ROCm error: out of memory

Batched bench FA on

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 384 | 128 | 1 | 512 | 0.635 | 604.85 | 5.115 | 25.02 | 5.750 | 89.04 |
| 384 | 128 | 2 | 1024 | 1.276 | 601.72 | 6.224 | 41.13 | 7.501 | 136.52 |
| 384 | 128 | 4 | 2048 | 2.819 | 544.78 | 9.554 | 53.59 | 12.373 | 165.52 |
| 384 | 128 | 8 | 4096 | 7.176 | 428.08 | 19.747 | 51.86 | 26.923 | 152.14 |
| 384 | 128 | 16 | 8192 | 20.508 | 299.59 | 40.661 | 50.37 | 61.168 | 133.93 |
| 384 | 128 | 32 | 16384 | 66.030 | 186.10 | 84.869 | 48.26 | 150.900 | 108.58 |

Your branch

Single bench

Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no

| model | size | params | backend | ngl | fa | test | t/s |
|-------|------|--------|---------|-----|----|------|-----|
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | pp512 | 744.91 ± 1.40 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 0 | tg128 | 27.14 ± 0.07 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | pp512 | 739.81 ± 1.30 |
| qwen2 ?B Q4_K - Small | 17.49 GiB | 32.76 B | ROCm | 99 | 1 | tg128 | 25.60 ± 0.02 |

build: 8739ed4 (4154)

Batched bench FA off

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 0, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 384 | 128 | 1 | 512 | 0.557 | 689.97 | 4.983 | 25.69 | 5.539 | 92.43 |
| 384 | 128 | 2 | 1024 | 1.091 | 703.82 | 6.615 | 38.70 | 7.706 | 132.88 |
| 384 | 128 | 4 | 2048 | 2.190 | 701.36 | 9.243 | 55.39 | 11.433 | 179.13 |
| 384 | 128 | 8 | 4096 | 4.795 | 640.71 | 17.012 | 60.19 | 21.807 | 187.83 |
| 384 | 128 | 16 | 8192 | 11.054 | 555.84 | 17.021 | 120.33 | 28.074 | 291.80 |

ggml/src/ggml-cuda/ggml-cuda.cu:70: ROCm error
ROCm error: out of memory
current device: 0, in function alloc at ggml/src/ggml-cuda/ggml-cuda.cu:275
ggml_cuda_device_malloc(&ptr, look_ahead_size, device)

Batched bench FA on

main: n_kv_max = 16384, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|-----|-------|
| 384 | 128 | 1 | 512 | 0.579 | 663.20 | 5.114 | 25.03 | 5.693 | 89.93 |
| 384 | 128 | 2 | 1024 | 1.100 | 698.49 | 6.243 | 41.01 | 7.342 | 139.47 |
| 384 | 128 | 4 | 2048 | 2.170 | 707.78 | 9.601 | 53.33 | 11.771 | 173.99 |
| 384 | 128 | 8 | 4096 | 4.749 | 646.81 | 19.833 | 51.63 | 24.582 | 166.62 |
| 384 | 128 | 16 | 8192 | 11.070 | 555.00 | 14.309 | 143.13 | 25.379 | 322.78 |
| 384 | 128 | 32 | 16384 | 28.232 | 435.25 | 31.475 | 130.14 | 59.707 | 274.41 |

hjc4869 commented Nov 22, 2024

GGML_CUDA_FA_ALL_QUANTS is not related to this issue. It compiles the FA kernels for all KV-cache quantization combinations, for example if you want -ctk q4_0 in combination with -ctv q8_0 (different quantization types), or the rarely used types like q4_1, q5_0 and q5_1.
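
A quick sketch of how that fits together, reusing the build flags from earlier in the thread (the model path is a placeholder, and this assumes llama-batched-bench accepts the same -ctk/-ctv cache-type options):

```bash
# build with FA kernels for all KV-cache quantization combinations compiled in
make GGML_HIPBLAS=1 GGML_CUDA_FA_ALL_QUANTS=1 AMDGPU_TARGETS=gfx1100 -j$(nproc)

# mixed K/V cache types (e.g. q4_0 keys with q8_0 values) need -fa and, for a
# combination like this one, GGML_CUDA_FA_ALL_QUANTS at build time
./llama-batched-bench -ngl 999 -m /path/to/model.gguf -fa -ctk q4_0 -ctv q8_0 \
    -npl 1,8,16 -npp 512 -ntg 128 -c 10240
```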

Mushoz commented Nov 22, 2024

Thanks for your explanation. Do you have any intention of trying to upstream these FA changes?

hjc4869 commented Nov 22, 2024

Nope, actually this has been done before and @JohannesGaessler had some explanation here: #7011

Although the current patch does improve performance to some degree, it is still not optimized for AMD hardware. I think we should try to push AMD to officially take responsibility for maintaining ROCm/HIP support in llama.cpp, just like MUSA/CANN, rather than relying on the community to port code optimized for NVIDIA. But these days they seem to be focusing primarily on contributing to the vLLM project rather than on things like llama.cpp that are easy to use on client devices/Radeon cards, unfortunately.

Mushoz commented Nov 22, 2024

Didn't the comment in that PR mention they would be open to merging improvements, as long as someone commits to maintaining that code? Not sure if that would be an option for you.

Is there some way we can reach out to AMD and convince them to commit to performance improvements for llama.cpp and to maintaining them?

hjc4869 commented Nov 22, 2024

I'm quite new to this area and not a GPU programming / machine learning expert. I do have some contacts at AMD, but they're quite busy these days, and diverting resources to a new project would probably need attention from executives, which is of course not something I can arrange.

AMD themselves actively use llama.cpp in their marketing materials, and Xilinx also maintains an internal fork for their AIE NPUs. But getting actual engineering resources to work on this project won't be so easy.

Mushoz commented Nov 22, 2024

That's unfortunate to hear. I really think the 7900 XTX could be competitive on both cost and performance. It just needs the support to get there.

The situation has already improved massively compared to 1-2 years ago though. So perhaps in the future we'll get there.

ccbadd commented Nov 22, 2024

> It has become better compared to 1-2 years ago but at least in my region (Germany) there currently don't really seem to be good second-hand offers for RX 7900 XTX cards.

A new 7900 XTX is pretty much the same price as (or reasonably close to) a used 3090, so doesn't that mean AMD cards are competitive? The used NVIDIA market is going to dry up at some point, and NVIDIA does not care to make affordable cards with enough VRAM for LLMs.

hjc4869 commented Nov 22, 2024

> doesn't that mean AMD cards are competitive?

The problem is that the 7900 XTX doesn't match the performance of an RTX 3090 in LLMs, either on paper or in real-world testing.

[benchmark charts comparing LLM performance across NVIDIA GPUs and Navi31]

I got some review data of LLM performance on NVIDIA GPUs earlier this month, and the position of Navi31 is quite clear.

Mushoz commented Nov 22, 2024

A new 7900 XTX is slightly more expensive than a used 3090 and performs slightly worse. With performance optimizations it would be competitive, but right now the 3090 has the advantage. It's a bit of a chicken-and-egg problem, to be honest.
