[benchmark] Add fused_moe_triton benchmark and tuning tools #2225

Merged (37 commits) on Nov 29, 2024
Commits (37):
- 8c10a24 fix a bug in v1_embeeding_request (BBuf, Nov 12, 2024)
- f2d6418 Merge branch 'sgl-project:main' into main (BBuf, Nov 12, 2024)
- 9e3bb3d fix test_embedding_models prompt length too long's bug (BBuf, Nov 12, 2024)
- 97c029c fix format (BBuf, Nov 12, 2024)
- 899b7f7 Merge branch 'sgl-project:main' into main (BBuf, Nov 13, 2024)
- 2afab09 Merge branch 'sgl-project:main' into main (BBuf, Nov 15, 2024)
- 6e6aec6 fix a small typo in docs (BBuf, Nov 15, 2024)
- a5f4383 Merge branch 'main' into main (BBuf, Nov 15, 2024)
- 4865f98 format backend.md (BBuf, Nov 15, 2024)
- 2fec3fe Apply suggestions from code review (merrymercy, Nov 15, 2024)
- b613bb6 Update docs/references/hyperparameter_tuning.md (merrymercy, Nov 15, 2024)
- b8dbac7 Merge branch 'sgl-project:main' into main (BBuf, Nov 16, 2024)
- 28be6f9 Merge branch 'sgl-project:main' into main (BBuf, Nov 25, 2024)
- 39b3309 add tuning fused configs for qwen2 57b and mixtral 8x7b (BBuf, Nov 25, 2024)
- ff9cd12 revert typo (BBuf, Nov 25, 2024)
- 4ce5d4c add fused_moe_triton benchmark and tuning tools (BBuf, Nov 27, 2024)
- aa4dda0 Merge branch 'sgl-project:main' into main (BBuf, Nov 27, 2024)
- 1240113 delete useless comment (BBuf, Nov 27, 2024)
- 9679714 Merge branch 'main' of github.com:BBuf/sglang (BBuf, Nov 27, 2024)
- e4d502e A more generalized benchmark implementation (BBuf, Nov 28, 2024)
- 1b68d31 Delete useless file (BBuf, Nov 28, 2024)
- e40084b refine benchmark_vllm_vs_sglang_fused_moe_triton.py output name (BBuf, Nov 28, 2024)
- 6c42078 lint (BBuf, Nov 28, 2024)
- adfc59f Merge pull request #1 from BBuf/lint (BBuf, Nov 28, 2024)
- cebf16e Merge branch 'main' into main (BBuf, Nov 28, 2024)
- 2946943 Merge branch 'sgl-project:main' into main (BBuf, Nov 28, 2024)
- 12f55fd fix chunked prefill size defualt value in GTX 4090 (BBuf, Nov 28, 2024)
- 4e419ac Merge pull request #2 from BBuf/fix_default_chunked_prefill_size_4090 (BBuf, Nov 28, 2024)
- 7a3e8e2 format (BBuf, Nov 28, 2024)
- 569f006 Merge pull request #3 from BBuf/fix_chunked_prefill_size_bug (BBuf, Nov 28, 2024)
- 05bccf0 more friendly commet (BBuf, Nov 29, 2024)
- d80745b fix comment (BBuf, Nov 29, 2024)
- a839d96 Lint fixed (BBuf, Nov 29, 2024)
- 9bf7bd4 fix init_type comment (BBuf, Nov 29, 2024)
- 31296f1 Merge branch 'main' into main (HaiShaw, Nov 29, 2024)
- a70cec9 Merge branch 'main' into main (HaiShaw, Nov 29, 2024)
- 238eb5c Merge branch 'main' into main (HaiShaw, Nov 29, 2024)
benchmark/kernels/fused_moe_triton/README.md (new file, 14 lines added)
## benchmark kernels

- `tuning_fused_moe_triton.py`: tunes the `fused_moe_triton` kernel. It is adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py, with added support for Qwen2-57B tuning.

For example, to tune the `Qwen/Qwen2-57B-A14B-Instruct-FP8` model's `fused_moe_triton` fp8_w8a8 kernel with TP4, run:

```bash
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py --model Qwen/Qwen2-57B-A14B-Instruct-FP8 --tp-size 4 --dtype fp8_w8a8 --tune
```

This produces a config file named `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json` in the current directory. Copy it into `sglang/srt/layers/fused_moe_triton/configs/` so that `sglang` can load it.
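A minimal sketch of that copy step, assuming you run from the repository root and that the `configs/` directory sits at the path shown above (adjust both paths to match your checkout):

```bash
cp E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json \
   sglang/srt/layers/fused_moe_triton/configs/
```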

- `benchmark_qwen2_57b_vllm_vs_sglang_fused_moe_triton.py`: benchmarks the `Qwen/Qwen2-57B-A14B-Instruct-FP8` model's `fused_moe_triton` fp8_w8a8 kernel in vLLM against the one in SGLang; an example invocation is shown below.
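The comparison script takes no command-line arguments in the version shown in this PR and writes its plots and timing data to `./configs/benchmark_ops/vllm_sglang_fused_moe/`. A minimal sketch of an invocation, assuming the script lives next to the tuning script in `benchmark/kernels/fused_moe_triton/`:

```bash
python benchmark/kernels/fused_moe_triton/benchmark_qwen2_57b_vllm_vs_sglang_fused_moe_triton.py
```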

New file (106 lines added): the vLLM vs SGLang fused MoE benchmark script
import torch
import triton

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

def fused_moe_vllm_api(x, w1, w2, input_gating, topk, w1_scale, w2_scale, a1_scale, a2_scale):
    # Thin wrapper around vLLM's fused MoE kernel with the fp8 w8a8 path enabled.
    return fused_moe_vllm(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


def fused_moe_sglang_api(x, w1, w2, input_gating, topk, w1_scale, w2_scale, a1_scale, a2_scale):
    # Thin wrapper around SGLang's fused MoE kernel with identical arguments.
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=["vllm fused moe", "sglang fused moe"],
        line_names=["vllm fused moe", "sglang fused moe"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    print(f"benchmark for batch_size={batch_size}.")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    dtype = torch.bfloat16
    init_dtype = torch.float16
    # Problem shape follows the Qwen2-57B-A14B-Instruct-FP8 TP=4 example from the README:
    # 64 experts, hidden size 3584, sharded intermediate size 1280, top-8 routing.
    num_tokens = batch_size
    num_experts = 64
    hidden_size = 3584
    shard_intermediate_size = 1280
    topk = 8
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype)
    w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype)
    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    w1_scale = torch.randn(num_experts, dtype=torch.float32)
    w2_scale = torch.randn(num_experts, dtype=torch.float32)
    a1_scale = torch.randn(1, dtype=torch.float32)
    a2_scale = torch.randn(1, dtype=torch.float32)

    # Cast the expert weights to fp8 for the fp8_w8a8 kernel path.
    w1 = w1.to(torch.float8_e4m3fn)
    w2 = w2.to(torch.float8_e4m3fn)

    # Warmup for fused_moe_vllm.
    for _ in range(10):
        y = fused_moe_vllm_api(x, w1, w2, input_gating, topk, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale)
    torch.cuda.synchronize()
    # Warmup for fused_moe_sglang.
    for _ in range(10):
        y = fused_moe_sglang_api(x, w1, w2, input_gating, topk, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale)
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm fused moe":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fused_moe_vllm_api(x, w1, w2, input_gating, topk, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale)[0],
            quantiles=quantiles,
        )
    elif provider == "sglang fused moe":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fused_moe_sglang_api(x, w1, w2, input_gating, topk, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale)[0],
            quantiles=quantiles,
        )
    return ms, min_ms, max_ms


benchmark.run(show_plots=True, print_data=True, save_path="./configs/benchmark_ops/vllm_sglang_fused_moe/")