[benchmark] Add fused_moe_triton benchmark and tuning tools #2225

Merged: 37 commits, Nov 29, 2024

Commits (37)
8c10a24
fix a bug in v1_embeeding_request
BBuf Nov 12, 2024
f2d6418
Merge branch 'sgl-project:main' into main
BBuf Nov 12, 2024
9e3bb3d
fix test_embedding_models prompt length too long's bug
BBuf Nov 12, 2024
97c029c
fix format
BBuf Nov 12, 2024
899b7f7
Merge branch 'sgl-project:main' into main
BBuf Nov 13, 2024
2afab09
Merge branch 'sgl-project:main' into main
BBuf Nov 15, 2024
6e6aec6
fix a small typo in docs
BBuf Nov 15, 2024
a5f4383
Merge branch 'main' into main
BBuf Nov 15, 2024
4865f98
format backend.md
BBuf Nov 15, 2024
2fec3fe
Apply suggestions from code review
merrymercy Nov 15, 2024
b613bb6
Update docs/references/hyperparameter_tuning.md
merrymercy Nov 15, 2024
b8dbac7
Merge branch 'sgl-project:main' into main
BBuf Nov 16, 2024
28be6f9
Merge branch 'sgl-project:main' into main
BBuf Nov 25, 2024
39b3309
add tuning fused configs for qwen2 57b and mixtral 8x7b
BBuf Nov 25, 2024
ff9cd12
revert typo
BBuf Nov 25, 2024
4ce5d4c
add fused_moe_triton benchmark and tuning tools
BBuf Nov 27, 2024
aa4dda0
Merge branch 'sgl-project:main' into main
BBuf Nov 27, 2024
1240113
delete useless comment
BBuf Nov 27, 2024
9679714
Merge branch 'main' of github.com:BBuf/sglang
BBuf Nov 27, 2024
e4d502e
A more generalized benchmark implementation
BBuf Nov 28, 2024
1b68d31
Delete useless file
BBuf Nov 28, 2024
e40084b
refine benchmark_vllm_vs_sglang_fused_moe_triton.py output name
BBuf Nov 28, 2024
6c42078
lint
BBuf Nov 28, 2024
adfc59f
Merge pull request #1 from BBuf/lint
BBuf Nov 28, 2024
cebf16e
Merge branch 'main' into main
BBuf Nov 28, 2024
2946943
Merge branch 'sgl-project:main' into main
BBuf Nov 28, 2024
12f55fd
fix chunked prefill size defualt value in GTX 4090
BBuf Nov 28, 2024
4e419ac
Merge pull request #2 from BBuf/fix_default_chunked_prefill_size_4090
BBuf Nov 28, 2024
7a3e8e2
format
BBuf Nov 28, 2024
569f006
Merge pull request #3 from BBuf/fix_chunked_prefill_size_bug
BBuf Nov 28, 2024
05bccf0
more friendly commet
BBuf Nov 29, 2024
d80745b
fix comment
BBuf Nov 29, 2024
a839d96
Lint fixed
BBuf Nov 29, 2024
9bf7bd4
fix init_type comment
BBuf Nov 29, 2024
31296f1
Merge branch 'main' into main
HaiShaw Nov 29, 2024
a70cec9
Merge branch 'main' into main
HaiShaw Nov 29, 2024
238eb5c
Merge branch 'main' into main
HaiShaw Nov 29, 2024
45 changes: 45 additions & 0 deletions benchmark/kernels/fused_moe_triton/README.md
@@ -0,0 +1,45 @@
## Benchmark Kernels

This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.

### Tuning Tool

- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.

Example usage:
```bash
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
    --tp-size 4 \
    --dtype fp8_w8a8 \
    --tune

# Tune Mixtral-8x7B with default settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tune
```

After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
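
For example, a minimal copy step (a sketch using the sample filename above; substitute the JSON file your tuning run actually generated and the path of your `sglang` checkout):

```bash
# Copy the tuned config into sglang's fused_moe_triton config directory
# (paths are illustrative; adjust them to your environment).
cp "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json" \
    sglang/srt/layers/fused_moe_triton/configs/
```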

### Performance Comparison Tool

- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.

Example usage:
```bash
# Compare with default settings (Mixtral model)
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py

# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
    --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \
    --use-fp8

# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
    --tp-size 4
```

The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
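
As a rough sketch of what to expect (the filenames come from the `plot_name` set in the benchmark script and may vary with your Triton version):

```bash
# Hypothetical listing of the output directory after a run.
ls ./configs/benchmark_ops/vllm_sglang_fused_moe/
# fused-moe-performance.png   <- latency curves for the vllm and sglang kernels
# fused-moe-performance.csv   <- raw timing data per batch size
```
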
237 changes: 237 additions & 0 deletions benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
@@ -0,0 +1,237 @@
import argparse
import numbers
from typing import Optional

import torch
import triton
from torch.nn import init
from torch.nn.parameter import Parameter
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from vllm.model_executor.layers.fused_moe.fused_moe import (
    get_moe_configs as get_moe_configs_vllm,
)
from vllm.utils import FlexibleArgumentParser

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
from sglang.srt.layers.fused_moe_triton.fused_moe import (
    get_moe_configs as get_moe_configs_sglang,
)


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name)

    # shard_intermediate_size is the per-TP-rank width of w1; the factor of 2
    # accounts for the fused gate and up projections.
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral, Grok1, etc.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    return {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
    }


def fused_moe_vllm_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_vllm(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        line_names=[
            "vllm_fused_moe_triton",
            "sglang_fused_moe_triton",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark for batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    if use_fp8:
        # Create the weights in the model dtype, then cast to fp8. Random scales
        # are fine here because only kernel latency is measured, not correctness.
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )
        w1_scale = w2_scale = a1_scale = a2_scale = None

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    # Warmup
    api_func = (
        fused_moe_vllm_api
        if provider == "vllm_fused_moe_triton"
        else fused_moe_sglang_api
    )
    for _ in range(10):
        y = api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms


def main():
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )


if __name__ == "__main__":
    main()