Benchmarking MLA #700

Open
YLGH opened this issue Dec 25, 2024 · 7 comments

YLGH commented Dec 25, 2024

Hi, I was trying to benchmark the MLA attention decoding kernel with the following script (adding CUDA timing events around BatchDecodeMlaWithPagedKVCacheWrapper::run in test_mla_decode_kernel). https://gist.github.com/YLGH/e8ebd7577d12f6c7963bcbae95e3b781

However, I'm seeing some very low numbers, such as an effective memory throughput of 50 GiB/s for batch_size=32, kv_len=16k, page_size=16. I feel like I'm misusing it somehow, but I couldn't find any examples of BatchDecodeMlaWithPagedKVCacheWrapper being used outside of test code, so I'm not sure whether I'm calling it correctly. I also couldn't get the nvbench suite to compile.
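
For reference, the timing is essentially the following (a minimal sketch; run_decode is just a placeholder for the wrapper.run call set up as in the gist):

import torch

# Time a decode call with CUDA events (run_decode is a stand-in for the
# BatchDecodeMlaWithPagedKVCacheWrapper.run call, planned/configured as in the gist).
def time_decode(run_decode, n_warmup=10, n_iters=100):
    for _ in range(n_warmup):          # warm up so compilation/caching doesn't skew timing
        run_decode()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        run_decode()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters / 1e3   # seconds per call

# Effective throughput: bytes of kv-cache the kernel has to read, divided by time per call.
def effective_gib_per_s(kv_cache_bytes, seconds_per_call):
    return kv_cache_bytes / seconds_per_call / 2**30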

yzh119 self-assigned this Dec 25, 2024

yzh119 commented Dec 26, 2024

I ran both your script and the nvbench scripts, and it seems the bandwidth utilization is indeed very low (by the way, can you show me the error message from compiling the nvbench suite?).
I'll check the ncu profiling results and see how we can fix this.

@tsu-bin do you have any ideas?


tsu-bin commented Dec 26, 2024

hi, please refer to this comment:
#551 (comment)

The bandwidth utilization reported by nvbench is not the real number, since the gmem traffic is set manually in the benchmark function:

state.add_global_memory_reads<uint8_t>(
    vec_bytes(q_nope) + vec_bytes(q_pe) + vec_bytes(ckv_data) + vec_bytes(kpe_data) +
        vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len),
    "Read");
state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");

The measurement from NCU is real; the utilization value is about 70%.

Please note that, for a large qo-head count, we have to read the kv-cache data multiple times, which is inevitable with the current scheduling design.
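
To make the accounting concrete (everything below is illustrative; kv_reads and the peak bandwidth are assumed values, not measured): nvbench just divides the declared byte count by the kernel time, so when the kv-cache is actually fetched multiple times, the real DRAM traffic, and the utilization NCU sees, is higher than the nvbench figure.

# Illustration of the accounting above. nvbench reports declared_bytes / time;
# NCU sees the actual traffic, which includes the kv-cache re-reads.
def utilization(declared_bytes, kv_bytes, kv_reads, kernel_seconds, peak_bytes_per_s):
    actual_bytes = declared_bytes + kv_bytes * (kv_reads - 1)      # extra traffic from re-reading kv
    reported = declared_bytes / kernel_seconds / peak_bytes_per_s  # what nvbench prints
    achieved = actual_bytes / kernel_seconds / peak_bytes_per_s    # roughly what NCU measures
    return reported, achieved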

BTW, I'm not sure when we can start working on the CuTe refactor of the MLA decode kernel.


YLGH commented Dec 26, 2024

Hey @tsu-bin @yzh119, thanks for taking a look and for the reply!

@tsu-bin, I was profiling at the PyTorch level by taking the size of my latent cache and dividing it by the total CPU time taken, as a first-order approximation to compare against the roofline.
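
Concretely, the arithmetic is something like this (assuming an fp16 latent cache with 512 ckv + 64 kpe dims per token; the ~11 ms below is just the kind of per-step time that yields ~50 GiB/s, not a quoted measurement):

# Back-of-the-envelope effective throughput (assumes fp16 latent cache, 512 ckv + 64 kpe dims).
def effective_throughput_gib_s(batch_size, kv_len, elapsed_s, bytes_per_token=(512 + 64) * 2):
    kv_cache_bytes = batch_size * kv_len * bytes_per_token  # latent cache read once per decode step
    return kv_cache_bytes / elapsed_s / 2**30

# batch_size=32, kv_len=16k is about 576 MiB of latent cache, so ~11 ms per decode step
# corresponds to roughly 50 GiB/s:
print(effective_throughput_gib_s(32, 16 * 1024, 11.3e-3))  # ~= 50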

I also tried q_shape = (32, 16 (qo heads), 512), but only saw slightly better results.

@yzh119 for my installation issue I was seeing: Could NOT find IBVerbs (missing: IBVERBS_LIBRARIES), similar to #674, but after pulling main I still had the same issue.


yzh119 commented Dec 27, 2024

Please note that, for a large qo-head count, we have to read the kv-cache data multiple times, which is inevitable with the current scheduling design.

Yes, we should improve the schedule for MLA.
Can we schedule a meeting next week to discuss the CuTe refactor? You can ping me on Slack or email me ([email protected]).


tsu-bin commented Dec 27, 2024

hi, I think the multiple loads of the kv-cache from gmem are inevitable; just like GEMM, tiles need to be loaded multiple times by different CTAs. All we can do is let one CTA process as many qo-heads as possible; besides that, maybe improving the L2 cache hit rate can alleviate it to some extent.
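
As a rough model of that trade-off (the tile sizes here are made-up, just to show the scaling):

import math

# Rough model: each CTA tile over qo-heads streams the request's kv-cache from gmem once,
# so kv traffic scales with the number of head tiles.
def kv_traffic_multiplier(num_qo_heads, heads_per_cta):
    return math.ceil(num_qo_heads / heads_per_cta)

# e.g. 128 qo-heads: 16 heads per CTA -> kv read 8x; 64 heads per CTA -> 2x.
# A better L2 hit rate only makes the repeated reads cheaper, it doesn't remove them.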

I will ping you on Slack.


yzh119 commented Jan 3, 2025

@tsu-bin @YLGH I checked the implementation again, and I realized that what we actually need is a fused-RoPE MQA kernel.

We can reuse our current GQA prefill attention template for MQA (MQA is a special form of GQA with num_kv_heads=1; the group size could be as large as 128 in MLA, so we should not use the decode codebase, which is only designed for small group sizes; the prefill attention template fuses the query-length and heads dimensions, which increases the operational intensity per CTA in the kernel, see Appendix A of the paper). However, there are two changes we need to make to the kernel:

  1. Enable different head dimensions for query/key and value; more specifically, we need qk_head_dim=192 and v_head_dim=128.
  2. Enable partially applying RoPE (only to the first 64 dimensions); see the sketch below.
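
A sketch of what change 2 means, assuming a 192-dim query/key where only the leading 64 dims are rotary and an interleaved cos/sin layout (this is just the intended math, not the flashinfer API):

import torch

# Rotate only the first rope_dim dims of each query/key head; the remaining
# dims pass through unchanged. Assumes interleaved (x0, x1) pairs for the rotation.
def apply_partial_rope(x, cos, sin, rope_dim=64):
    # x: [..., head_dim]; cos, sin: [..., rope_dim // 2]
    rot, rest = x[..., :rope_dim], x[..., rope_dim:]
    x1, x2 = rot[..., 0::2], rot[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
    return torch.cat((rotated, rest), dim=-1)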


tsu-bin commented Jan 6, 2025

hi @yzh119

the prefill attention template fuses the query-length and heads dimensions, which increases the operational intensity per CTA in the kernel, see Appendix A of the paper

Yes, this is exactly what I want to do: the first matmul is 'qo-heads-chunk * (kpe-dim + ckv-dim) * kv-sequence-chunk', and the 'qo-heads-chunk * (kpe-dim + ckv-dim)' data will be loaded into smem, instead of into rmem as in the current MLA decode kernel (a shape sketch is below).
I think a dedicated MLA kernel is still preferable, because: 1) yes, you can load the kpe-cache and ckv-cache into smem so that they concatenate into one contiguous kv-cache tensor, but the ckv-cache will be used both as Key and Value data, so adapting the prefill kernel may not be very easy; 2) please note that the ckv-dim is rather large (512 for deepseek-v2 and v3), so the matmul has a large reduction axis, which I think may need a different scheduling design than the normal prefill kernel.
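
For clarity, the shape of that first matmul is roughly the following (the tile sizes are placeholders; the real kernel would do this as fp16/bf16 tensor-core MMAs with the operands staged in smem):

import torch

# Shape-check of the first matmul: a qo-heads tile against a kv-sequence tile over the
# concatenated (ckv + kpe) dimension. Tile sizes are placeholders, dims are deepseek-v2.
heads_chunk, kv_chunk = 16, 64
ckv_dim, kpe_dim = 512, 64

q = torch.randn(heads_chunk, ckv_dim + kpe_dim)    # qo-heads-chunk x 576
kv = torch.randn(kv_chunk, ckv_dim + kpe_dim)      # kv-sequence-chunk x 576
scores = q @ kv.T                                  # heads_chunk x kv_chunk, 576-wide reduction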

we need qk_head_dim=192, and v_head_dim=128.

Please note that for the Mat-Absorb version of the MLA decode kernel, qk_head_dim = 512+64 and v_head_dim = 512 for deepseek-v2. So smem can't accommodate all the qo-heads' data, and we have to load the qo-heads data multiple times.
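
Putting numbers on that (my arithmetic, assuming 128 qo-heads and fp16):

# Why smem can't hold all qo-heads in the Mat-Absorb form (assumes 128 heads, fp16).
num_qo_heads = 128
qk_head_dim = 512 + 64            # absorbed ckv dim + kpe dim
bytes_per_elem = 2                # fp16

q_tile_bytes = num_qo_heads * qk_head_dim * bytes_per_elem
print(q_tile_bytes / 1024)        # ~144 KiB for the query heads alone, before any kv tile
                                  # or accumulators, so the heads must be processed in passes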

I'm still trying hard to implement a new CuTe (tensor-core) version of the MLA decode kernel; it still needs some time.
