[ROCM MOE] Enable ROCM AITER Block MOE For DeepSeek R1/V3 #3788

Status: Open. Wants to merge 34 commits into base: main.

Commits (34)
bcda23e  add block gemm tune amd config (Feb 8, 2025)
e0dda9a  retune MI300 block gemm (BruceXcluding, Feb 8, 2025)
ce41796  retune triton block gemm AMD Radeon (BruceXcluding, Feb 8, 2025)
b2d9166  retune moe bs64 for AMD Radeon (Feb 9, 2025)
806ba61  fall back to original MI300X block gemm config (Feb 9, 2025)
582f892  fall back origin bs4 block gemm for AMD Radeon (Feb 9, 2025)
bcbb54f  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 9, 2025)
983e54c  fall back to new bs4 block gemm for AMD Radeon (Feb 9, 2025)
b1fb432  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 9, 2025)
d173782  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 10, 2025)
8f6198f  Merge branch 'main' into main_perf (HaiShaw, Feb 10, 2025)
ec9595b  Merge branch 'main' into main_perf (HaiShaw, Feb 10, 2025)
6944f82  Merge branch 'main' into main_perf (HaiShaw, Feb 11, 2025)
d0961f3  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 13, 2025)
1fc3c2c  optimal moe tuning (BruceXcluding, Feb 13, 2025)
0497d76  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 14, 2025)
7f5a0fc  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 14, 2025)
8c0d407  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 17, 2025)
316cf24  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 18, 2025)
b94c3a1  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 20, 2025)
ea22665  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 22, 2025)
6fdca41  add aiter block moe (Feb 22, 2025)
52a1395  Clang Format (BruceXcluding, Feb 22, 2025)
d53bcc2  fix get env in layer.py (Feb 22, 2025)
3723682  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 23, 2025)
a45a8bb  enable EP MoE for AITER backend (valarLip, Feb 23, 2025)
0093754  do shuffle at data process time (valarLip, Feb 23, 2025)
70f66a7  fix is_hip runtime calling (BruceXcluding, Feb 24, 2025)
466149b  Clang Format (BruceXcluding, Feb 24, 2025)
0b4af56  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 24, 2025)
2041682  fix is_hip calling in dsv.py (BruceXcluding, Feb 24, 2025)
629c858  merge 'main' into main (BruceXcluding, Feb 25, 2025)
83f2589  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 25, 2025)
e51ba86  Merge branch 'sgl-project:main' into main_perf (BruceXcluding, Feb 25, 2025)
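
The diff below gates the new code path on ROCm plus the SGLANG_ROCM_AITER_BLOCK_MOE environment variable. The following is a minimal sketch of that check, not code from this PR; the helper name use_rocm_aiter_block_moe is illustrative, and the PR instead inlines the same condition at each call site.

import os

from sglang.srt.utils import is_hip


def use_rocm_aiter_block_moe() -> bool:
    # True only on ROCm (HIP) builds with the AITER block-MoE switch set to "1".
    return is_hip() and os.getenv("SGLANG_ROCM_AITER_BLOCK_MOE") == "1"

In other words, the AITER block-MoE kernels are opt-in: export SGLANG_ROCM_AITER_BLOCK_MOE=1 on a ROCm build before launching the server, and the default Triton path is used otherwise.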
216 changes: 201 additions & 15 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,4 +1,5 @@
import logging
import os
from typing import Callable, List, Optional, Tuple

import torch
@@ -17,15 +18,24 @@
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_utils import (
BlockQuantScaleParameter,
apply_w8a8_block_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip, set_weight_attrs

is_hip_ = is_hip()

logger = logging.getLogger(__name__)


@@ -115,6 +125,8 @@ def __init__(
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
num_shared_experts: Optional[int] = 0,
routed_scaling_factor: Optional[float] = 1.0,
):
super().__init__()

@@ -127,6 +139,7 @@ def __init__(
self.tp_rank = get_tensor_model_parallel_rank()

self.num_experts = num_experts
self.num_shared_experts = num_shared_experts
assert self.num_experts % self.tp_size == 0
self.num_experts_per_partition = self.num_experts // self.tp_size
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
@@ -155,6 +168,18 @@ def __init__(
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
if is_hip_ and os.getenv("SGLANG_ROCM_AITER_BLOCK_MOE") == "1":
self.routed_scaling_factor = routed_scaling_factor

self.expert_mask = torch.zeros(
(self.num_experts + self.num_shared_experts + 1),
device="cuda",
dtype=torch.int,
)
self.expert_mask[self.start_expert_id : self.end_expert_id + 1] = 1
self.expert_mask[self.num_experts : -1] = 1

self.num_experts_per_partition += self.num_shared_experts

self.quant_method.create_weights(
layer=self,
@@ -171,6 +196,21 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert self.activation == "silu"

if is_hip_ and os.getenv("SGLANG_ROCM_AITER_BLOCK_MOE") == "1":
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
)
return final_hidden_states

if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
@@ -317,6 +357,7 @@ def make_expert_params_mapping(
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_shared_experts: Optional[int] = 0,
) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
@@ -336,6 +377,27 @@
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
] + [
(
(
"experts.w13_"
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else "experts.w2_"
),
(
f"shared_experts.{expert_id}.{weight_name}."
if num_shared_experts >= 2
else f"shared_experts.{weight_name}."
),
-num_shared_experts + expert_id,
shard_id,
)
for expert_id in range(num_shared_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]

def weight_loader(
@@ -346,9 +408,13 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
if expert_id >= 0 and (
expert_id < self.start_expert_id or expert_id > self.end_expert_id
):
return
expert_id = expert_id - self.start_expert_id
# expert_id < 0 means shared expert
if expert_id >= 0:
expert_id = expert_id - self.start_expert_id

if shard_id not in ("w1", "w2", "w3"):
raise ValueError(
@@ -394,6 +460,13 @@ def _load_fp8_scale(
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale_inv" in weight_name:
# elif getattr(param, "quant_method", None) == FusedMoeWeightScaleSupported.BLOCK.value:
if shard_id in ("w1", "w3"):
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
else:
param_data[expert_id] = loaded_weight
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
@@ -498,6 +571,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None

def create_weights(
self,
@@ -538,21 +612,53 @@ def create_weights(
set_weight_attrs(w2_weight, extra_weight_attrs)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
2,
((intermediate_size + block_n - 1) // block_n),
((hidden_size + block_k - 1) // block_k),
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)

w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update({"quant_method": "tensor"})
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
@@ -629,6 +735,41 @@ def process_weights_after_loading(self, layer: Module) -> None:
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if os.getenv("SGLANG_ROCM_AITER_BLOCK_MOE") == "1":
import aiter
from aiter.ops.shuffle import shuffle_weight

layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
return

def apply(
@@ -643,4 +784,49 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
if is_hip_ and os.getenv("SGLANG_ROCM_AITER_BLOCK_MOE") == "1":
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=layer.correction_bias,
)

# TODO: these can be removed once "select_experts" becomes an in-place op
token = x.shape[0]
layer.ns_topk_weights[:token] = topk_weights * layer.routed_scaling_factor
layer.ns_topk_ids[:token] = topk_ids
topk_ids = layer.total_topk_ids[:token]
topk_weights = layer.total_topk_weights[:token]

return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
expert_mask=layer.expert_mask,
)
else:
raise NotImplementedError
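
For reference, below is a small CPU-only illustration (a sketch under assumed toy sizes, not code from this PR) of the expert_mask layout built in EPMoE.__init__ above: routed experts occupy indices [0, num_experts), the shared experts are appended after them, and only the locally owned routed slice plus every shared expert is marked 1; the final extra slot stays 0 and appears to act as a sentinel entry.

import torch

num_experts = 8          # routed experts across all EP ranks (toy size)
num_shared_experts = 2   # DeepSeek-style shared experts
tp_size, tp_rank = 4, 1  # this rank owns routed experts 2..3

experts_per_rank = num_experts // tp_size
start_id = tp_rank * experts_per_rank
end_id = start_id + experts_per_rank - 1

expert_mask = torch.zeros(num_experts + num_shared_experts + 1, dtype=torch.int)
expert_mask[start_id : end_id + 1] = 1  # local routed partition
expert_mask[num_experts:-1] = 1         # all shared experts
print(expert_mask)  # tensor([0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0], dtype=torch.int32)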