feat: fused_moe fp8 monkey patch (#2174)
zhyncs authored Nov 25, 2024
1 parent a866b65 · commit 55842eb
Showing 1 changed file with 68 additions and 18 deletions.
python/sglang/srt/layers/quantization/__init__.py (86 changes: 68 additions & 18 deletions)
@@ -1,18 +1,19 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
-from typing import Dict, Type
+from typing import Callable, Dict, Optional, Type
 
+import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -30,8 +31,6 @@
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
-    # The order of gptq methods is important for config.py iteration over
-    # override_quantization_method(..)
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
@@ -47,38 +46,89 @@
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
-        raise ValueError(f"Invalid quantization method: {quantization}")
+        raise ValueError(
+            f"Invalid quantization method: {quantization}. "
+            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
+        )
     return QUANTIZATION_METHODS[quantization]
 
 
-__all__ = [
-    "QuantizationConfig",
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
+def fp8_moe_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+) -> torch.Tensor:
+    """Enhanced apply method for FP8 MoE."""
+    from sglang.srt.layers.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+    # Expert selection
+    topk_weights, topk_ids = FusedMoE.select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
+
+    # Expert fusion with FP8 quantization
+    return fused_experts(
+        x,
+        layer.w13_weight,
+        layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=True,
+        use_fp8_w8a8=True,
+        w1_scale=layer.w13_weight_scale,
+        w2_scale=layer.w2_weight_scale,
+        a1_scale=layer.w13_input_scale,
+        a2_scale=layer.w2_input_scale,
+    )
 
 
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import (
-        Fp8LinearMethod,
-        Fp8MoEMethod,
-    )
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.linear import UnquantizedLinearMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
-            from sglang.srt.layers.linear import UnquantizedLinearMethod
-
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
         return Fp8MoEMethod(self)
     return None
 
 
-setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+def apply_monkey_patches():
+    """Apply all monkey patches in one place."""
+    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
+    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+
+
+# Apply patches when module is imported
+apply_monkey_patches()
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
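
Conceptually, the patched fp8_moe_apply routes each token to its top-k experts, then evaluates a SwiGLU expert MLP whose FP8 weights are dequantized via per-expert scales. The unfused PyTorch sketch below is a hypothetical illustration of that computation, not code from the commit: the name moe_reference and the loop structure are invented for clarity, activation-side FP8 quantization (a1_scale/a2_scale) is omitted, and the real fused_experts Triton kernel fuses and tiles all of this.

    import torch

    def moe_reference(x, w13, w2, w13_scale, w2_scale, topk_weights, topk_ids):
        # Hypothetical unfused reference; accumulate in float32 for clarity.
        out = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
        num_tokens, top_k = topk_ids.shape
        for t in range(num_tokens):
            for k in range(top_k):
                e = topk_ids[t, k]
                # Dequantize the FP8 expert weights with their per-expert scales.
                w13_e = w13[e].float() * w13_scale[e]
                w2_e = w2[e].float() * w2_scale[e]
                # w13 packs the gate and up projections of the SwiGLU MLP.
                gate, up = (x[t].float() @ w13_e.t()).chunk(2)
                h = torch.nn.functional.silu(gate) * up
                # Weight each expert's output by its routing weight.
                out[t] += topk_weights[t, k].float() * (h @ w2_e.t())
        return out.to(x.dtype)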
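Because apply_monkey_patches() runs at import time, importing this module is enough to rebind the two vLLM attributes. A minimal sketch of how the rebinding could be checked in an interactive session, assuming vllm and sglang are installed (the assertions are illustrative, not part of the commit):

    from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

    import sglang.srt.layers.quantization as sgl_quant  # importing applies the patches

    # The vLLM classes now resolve to sglang's replacement functions.
    assert Fp8Config.get_quant_method is sgl_quant.fp8_get_quant_method
    assert Fp8MoEMethod.apply is sgl_quant.fp8_moe_apply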
