
split ReplicatedLinear used in MLA prefill computing along hidden_states[0] to save duplicated computing on all devices #3688

Open · wants to merge 17 commits into base: main
Conversation

ZJLi2013

@ZJLi2013 commented Feb 19, 2025

Motivation

In MLA there are a few ReplicatedLinear ops, e.g. q_a_proj and kv_a_proj_with_mqa, meaning the same hidden_states tensor is computed on all devices. This duplicated GEMM work can be reduced by splitting hidden_states across tp_size along the batch_size * seqlen (a.k.a. total_num_tokens) dimension. Currently this is only useful in prefill computing.

Modifications

  1. Replace ReplicatedLinear with dp_linear in deepseek-v2.py, which splits the input hidden_states along the total_num_tokens dimension and performs an all_gather as the last step.
  2. Add test_dp_linear.py for a prefill/decoding benchmark.
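The split-then-gather idea can be sketched in a few lines of NumPy, simulating the tp_size ranks in a single process (dp_linear_forward is an illustrative name, not the actual sglang implementation, and the final concatenate stands in for the all_gather across ranks):

```python
import numpy as np

def dp_linear_forward(hidden_states, weight, tp_size):
    # Instead of every rank running the full GEMM (ReplicatedLinear),
    # split the input along the total_num_tokens dim so each rank
    # only computes its own shard of tokens.
    chunks = np.array_split(hidden_states, tp_size, axis=0)
    local_outs = [c @ weight.T for c in chunks]  # one local GEMM per rank
    # all_gather along the token dim reassembles the full output.
    return np.concatenate(local_outs, axis=0)

rng = np.random.default_rng(0)
x = rng.standard_normal((32 * 4, 16))  # (total_num_tokens, hidden_size)
w = rng.standard_normal((8, 16))       # (out_features, hidden_size)

ref = x @ w.T                          # replicated: all tokens on every rank
out = dp_linear_forward(x, w, tp_size=8)
assert np.allclose(out, ref)
```

Each rank's GEMM is 1/tp_size of the replicated one, at the cost of an extra all_gather; this trade only pays off when total_num_tokens is large, which is why the PR limits it to prefill.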

MI308 Benchmark Results

python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 1  --model /data/DeepSeek-V3/ --tp 8 --trust-remote-code

Before: 5122.06 tok/s; after: 5518.23 tok/s.
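For reference, the relative gain implied by the two numbers above is roughly 7.7%:

```python
# Relative throughput gain from the bench_one_batch numbers quoted above.
before, after = 5122.06, 5518.23  # tok/s
gain = (after - before) / before
print(f"{gain:.1%}")  # ~7.7%
```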


@ZJLi2013 (Author)

Serving benchmark results update:

Baseline (without use_dp_linear):

# prefill
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 32
Successful requests:                     200
Benchmark duration (s):                  364.37
Total input tokens:                      640000
Total generated tokens:                  200
Total generated tokens (retokenized):    197
Request throughput (req/s):              0.55
Input token throughput (tok/s):          1756.46
Output token throughput (tok/s):         0.55
Total token throughput (tok/s):          1757.01
Concurrency:                             29.91
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   54487.73
Median E2E Latency (ms):                 38634.45
---------------Time to First Token----------------
Mean TTFT (ms):                          53974.92
Median TTFT (ms):                        38597.53
P99 TTFT (ms):                           128695.85
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00
Median TPOT (ms):                        0.00
P99 TPOT (ms):                           0.00
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P99 ITL (ms):                            0.00
==================================================
# e2e decoding
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 32
Successful requests:                     200
Benchmark duration (s):                  728.66
Total input tokens:                      640000
Total generated tokens:                  100000
Total generated tokens (retokenized):    99595
Request throughput (req/s):              0.27
Input token throughput (tok/s):          878.32
Output token throughput (tok/s):         137.24
Total token throughput (tok/s):          1015.56
Concurrency:                             30.29
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   110342.08
Median E2E Latency (ms):                 111139.55
---------------Time to First Token----------------
Mean TTFT (ms):                          35991.87
Median TTFT (ms):                        37487.63
P99 TTFT (ms):                           62302.38
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          149.00
Median TPOT (ms):                        145.35
P99 TPOT (ms):                           210.27
---------------Inter-token Latency----------------
Mean ITL (ms):                           149.01
Median ITL (ms):                         104.52
P99 ITL (ms):                            168.93
==================================================

With use_dp_linear:

# prefill 
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 3
Successful requests:                     200
Benchmark duration (s):                  254.22
Total input tokens:                      640000
Total generated tokens:                  200
Total generated tokens (retokenized):    197
Request throughput (req/s):              0.79
Input token throughput (tok/s):          2517.47
Output token throughput (tok/s):         0.79
Total token throughput (tok/s):          2518.26
Concurrency:                             2.99
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   3800.61
Median E2E Latency (ms):                 3420.09
---------------Time to First Token----------------
Mean TTFT (ms):                          3748.08
Median TTFT (ms):                        3418.47
P99 TTFT (ms):                           7227.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00
Median TPOT (ms):                        0.00
P99 TPOT (ms):                           0.00
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P99 ITL (ms):                            0.00
==================================================

# e2e decoding
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 32
Successful requests:                     200
Benchmark duration (s):                  690.73
Total input tokens:                      640000
Total generated tokens:                  100000
Total generated tokens (retokenized):    99611
Request throughput (req/s):              0.29
Input token throughput (tok/s):          926.55
Output token throughput (tok/s):         144.77
Total token throughput (tok/s):          1071.33
Concurrency:                             30.53
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   105443.90
Median E2E Latency (ms):                 105185.21
---------------Time to First Token----------------
Mean TTFT (ms):                          32991.32
Median TTFT (ms):                        34231.23
P99 TTFT (ms):                           61317.65
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          145.20
Median TPOT (ms):                        143.91
P99 TPOT (ms):                           207.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           145.20
Median ITL (ms):                         104.60
P99 ITL (ms):                            123.57
==================================================

@zhyncs (Member) commented Feb 20, 2025

Hi, please let me know when it's ready for review. Thanks!

@zhyncs zhyncs self-assigned this Feb 20, 2025
@ZJLi2013 ZJLi2013 marked this pull request as ready for review February 25, 2025 01:40
@ZJLi2013 (Author)

> Hi, please let me know when it's ready for review. Thanks!

Hi @zhyncs, many thanks for the review. The benchmark tests so far cover only a few ISL/OSL/num_prompts combinations on H20/MI30x.
