diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 86085284d15..b0f7150cdb4 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -22,7 +22,7 @@
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS, get_quantization_config
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -218,7 +218,7 @@ def _parse_quant_hf_config(self):
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        supported_quantization = QUANTIZATION_METHODS
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -253,7 +253,8 @@ def _verify_quantization(self) -> None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for _, method in QUANTIZATION_METHODS.items():
+            for name in QUANTIZATION_METHODS:
+                method = get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization
                 )
diff --git a/python/sglang/srt/layers/__init__.py b/python/sglang/srt/layers/__init__.py
new file mode 100644
index 00000000000..5337982ddcb
--- /dev/null
+++ b/python/sglang/srt/layers/__init__.py
@@ -0,0 +1,18 @@
+def patch_vllm_linear_base_isinstance():
+    import builtins
+
+    from vllm.model_executor.layers.linear import LinearBase
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+
+    original_isinstance = builtins.isinstance
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
+
+
+patch_vllm_linear_base_isinstance()
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 928643b70c2..421cc49fad1 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -1,64 +1,88 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Dict, List, Optional, Type
 
 import torch
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import (
     AWQMarlinConfig,
     AWQMoEMethod,
 )
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
-from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
-from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "blockwise_int8": BlockInt8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "marlin": MarlinConfig,
-    "modelopt": ModelOptFp8Config,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
-    "w8a8_int8": W8A8Int8Config,
+
+QUANTIZATION_METHODS: List[str] = {
+    "aqlm",
+    "awq",
+    "deepspeedfp",
+    "tpu_int8",
+    "fp8",
+    "blockwise_int8",
+    "fbgemm_fp8",
+    "marlin",
+    "modelopt",
+    "gguf",
+    "gptq_marlin_24",
+    "gptq_marlin",
+    "awq_marlin",
+    "gptq",
+    "compressed-tensors",
+    "bitsandbytes",
+    "qqq",
+    "experts_int8",
+    "w8a8_int8",
 }
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
-        raise ValueError(
-            f"Invalid quantization method: {quantization}. "
-            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
-        )
-    return QUANTIZATION_METHODS[quantization]
+        raise ValueError(f"Invalid quantization method: {quantization}.")
+
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq import AWQConfig
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    from .blockwise_int8 import BlockInt8Config
+    from .fp8 import Fp8Config
+    from .modelopt_quant import ModelOptFp8Config
+    from .w8a8_int8 import W8A8Int8Config
+
+    method_to_config: Dict[str, Type[QuantizationConfig]] = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fp8": Fp8Config,
+        "blockwise_int8": BlockInt8Config,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "marlin": MarlinConfig,
+        "modelopt": ModelOptFp8Config,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "awq_marlin": AWQMarlinConfig,
+        "gptq": GPTQConfig,
+        "compressed-tensors": CompressedTensorsConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "experts_int8": ExpertsInt8Config,
+        "w8a8_int8": W8A8Int8Config,
+    }
+
+    return method_to_config[quantization]
 
 
 def gptq_get_quant_method(self, layer, prefix):
@@ -133,23 +157,6 @@ def awq_moe_method_apply(
     )
 
 
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
-    from vllm.model_executor.layers.linear import LinearBase
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-
-    original_isinstance = builtins.isinstance
-
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        return original_isinstance(obj, classinfo)
-
-    builtins.isinstance = patched_isinstance
-
-
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
@@ -159,7 +166,6 @@ def apply_monkey_patches():
     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
 
-patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
 apply_monkey_patches()