Resolve LinearBase circular dependency
zhengy001 committed Feb 25, 2025
1 parent 6ce9dbe commit 3f09b46
Showing 3 changed files with 92 additions and 67 deletions.
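As the diff shows, the fix has two parts: the module-level vllm quantization imports move inside get_quantization_config, so QUANTIZATION_METHODS becomes a plain collection of method names that can be imported without pulling in vllm, and the LinearBase isinstance patch moves out of sglang/srt/layers/quantization/__init__.py into sglang/srt/layers/__init__.py. Below is a minimal, self-contained sketch of the deferred-import half of the fix; the names are illustrative (stdlib classes stand in for vllm's config classes), not sglang's own.

# lazy_registry.py -- minimal sketch of the deferred-import pattern used in
# this commit (names here are illustrative, not sglang's).
from typing import Dict, Type

QUANT_METHOD_NAMES = ["fp8", "awq"]  # plain strings: importable with no side effects


def get_config_class(method: str) -> Type:
    """Resolve a method name to its config class, importing lazily."""
    if method not in QUANT_METHOD_NAMES:
        raise ValueError(f"Invalid quantization method: {method}.")

    # The import happens here, at call time, instead of at module import
    # time -- so merely importing this module can no longer participate in
    # an import cycle with the backend package. Decimal and Fraction stand
    # in for the heavyweight vllm config classes.
    from decimal import Decimal
    from fractions import Fraction

    method_to_config: Dict[str, Type] = {"fp8": Decimal, "awq": Fraction}
    return method_to_config[method]


if __name__ == "__main__":
    print(get_config_class("fp8"))  # <class 'decimal.Decimal'>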
7 changes: 4 additions & 3 deletions python/sglang/srt/configs/model_config.py
@@ -22,7 +22,7 @@
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS, get_quantization_config
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -218,7 +218,7 @@ def _parse_quant_hf_config(self):
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        supported_quantization = QUANTIZATION_METHODS
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -253,7 +253,8 @@ def _verify_quantization(self) -> None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for _, method in QUANTIZATION_METHODS.items():
+            for name in QUANTIZATION_METHODS:
+                method = get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization
                 )
18 changes: 18 additions & 0 deletions python/sglang/srt/layers/__init__.py
@@ -0,0 +1,18 @@
+def patch_vllm_linear_base_isinstance():
+    import builtins
+
+    from vllm.model_executor.layers.linear import LinearBase
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+
+    original_isinstance = builtins.isinstance
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
+
+
+patch_vllm_linear_base_isinstance()
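This new module patches builtins.isinstance process-wide, so vllm code that type-checks a layer against its own LinearBase also matches instances of sglang's replacement class. A toy reproduction of the mechanism, with stand-in classes in place of the real vllm and sglang LinearBase classes:

# Toy reproduction of the isinstance patch above; VendorLinearBase and
# PatchedLinearBase are stand-ins, not the real vllm/sglang classes.
import builtins


class VendorLinearBase:  # what third-party code type-checks against
    pass


class PatchedLinearBase:  # the replacement hierarchy actually instantiated
    pass


original_isinstance = builtins.isinstance


def patched_isinstance(obj, classinfo):
    # Redirect checks against the vendor class to the patched class.
    if classinfo is VendorLinearBase:
        return original_isinstance(obj, PatchedLinearBase)
    return original_isinstance(obj, classinfo)


builtins.isinstance = patched_isinstance

layer = PatchedLinearBase()
print(isinstance(layer, VendorLinearBase))  # True: the check was redirected
builtins.isinstance = original_isinstance   # restore, to keep the demo tidy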
134 changes: 70 additions & 64 deletions python/sglang/srt/layers/quantization/__init__.py
@@ -1,64 +1,88 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Dict, List, Optional, Type
 
 import torch
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import (
     AWQMarlinConfig,
     AWQMoEMethod,
 )
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
-from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
-from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "blockwise_int8": BlockInt8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "marlin": MarlinConfig,
-    "modelopt": ModelOptFp8Config,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
-    "w8a8_int8": W8A8Int8Config,
-}
+QUANTIZATION_METHODS: List[str] = {
+    "aqlm",
+    "awq",
+    "deepspeedfp",
+    "tpu_int8",
+    "fp8",
+    "blockwise_int8",
+    "fbgemm_fp8",
+    "marlin",
+    "modelopt",
+    "gguf",
+    "gptq_marlin_24",
+    "gptq_marlin",
+    "awq_marlin",
+    "gptq",
+    "compressed-tensors",
+    "bitsandbytes",
+    "qqq",
+    "experts_int8",
+    "w8a8_int8",
+}
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
-        raise ValueError(
-            f"Invalid quantization method: {quantization}. "
-            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
-        )
-    return QUANTIZATION_METHODS[quantization]
+        raise ValueError(f"Invalid quantization method: {quantization}.")
+
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq import AWQConfig
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    from .blockwise_int8 import BlockInt8Config
+    from .fp8 import Fp8Config
+    from .modelopt_quant import ModelOptFp8Config
+    from .w8a8_int8 import W8A8Int8Config
+
+    method_to_config: Dict[str, Type[QuantizationConfig]] = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fp8": Fp8Config,
+        "blockwise_int8": BlockInt8Config,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "marlin": MarlinConfig,
+        "modelopt": ModelOptFp8Config,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "awq_marlin": AWQMarlinConfig,
+        "gptq": GPTQConfig,
+        "compressed-tensors": CompressedTensorsConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "experts_int8": ExpertsInt8Config,
+        "w8a8_int8": W8A8Int8Config,
+    }
+
+    return method_to_config[quantization]
 
 
 def gptq_get_quant_method(self, layer, prefix):
@@ -133,23 +157,6 @@ def awq_moe_method_apply(
     )
 
 
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
-    from vllm.model_executor.layers.linear import LinearBase
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-
-    original_isinstance = builtins.isinstance
-
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        return original_isinstance(obj, classinfo)
-
-    builtins.isinstance = patched_isinstance
-
-
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
@@ -159,7 +166,6 @@ def apply_monkey_patches():
     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
 
-patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
 apply_monkey_patches()

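With the registry inverted this way, importing sglang.srt.layers.quantization stays cheap and cycle-free; the vllm config modules load only on the first call to get_quantization_config. A usage sketch (assuming a working vllm installation, since the deferred imports still require it):

# Usage sketch: resolving a quantization config class by name.
from sglang.srt.layers.quantization import (
    QUANTIZATION_METHODS,
    get_quantization_config,
)

assert "fp8" in QUANTIZATION_METHODS        # membership check: no vllm import
Fp8Config = get_quantization_config("fp8")  # the deferred vllm imports run here
print(Fp8Config)                            # <class 'sglang.srt.layers.quantization.fp8.Fp8Config'>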
