From b17b685943946248b589f4f87cf64dbd46ad69e3 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Wed, 27 Nov 2024 15:42:20 +0800 Subject: [PATCH] remove vllm model_loader deps and support qqq quantization --- python/sglang/srt/layers/linear.py | 1 + .../srt/layers/quantization/__init__.py | 2 +- python/sglang/srt/layers/quantization/qqq.py | 300 ++++ python/sglang/srt/lora/lora.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 65 +- python/sglang/srt/model_loader/__init__.py | 50 + python/sglang/srt/model_loader/loader.py | 1255 +++++++++++++++++ python/sglang/srt/model_loader/utils.py | 40 + .../sglang/srt/model_loader/weight_utils.py | 646 +++++++++ python/sglang/srt/models/baichuan.py | 2 +- python/sglang/srt/models/chatglm.py | 2 +- python/sglang/srt/models/commandr.py | 2 +- python/sglang/srt/models/dbrx.py | 2 +- python/sglang/srt/models/deepseek.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/exaone.py | 2 +- python/sglang/srt/models/gemma.py | 2 +- python/sglang/srt/models/gemma2.py | 4 +- python/sglang/srt/models/gpt2.py | 2 +- python/sglang/srt/models/gpt_bigcode.py | 2 +- python/sglang/srt/models/grok.py | 4 +- python/sglang/srt/models/internlm2.py | 2 +- python/sglang/srt/models/llama.py | 2 +- .../sglang/srt/models/llama_classification.py | 2 +- python/sglang/srt/models/llama_embedding.py | 2 +- python/sglang/srt/models/llama_reward.py | 1 + python/sglang/srt/models/llava.py | 2 +- python/sglang/srt/models/llavavid.py | 2 +- python/sglang/srt/models/minicpm.py | 2 +- python/sglang/srt/models/minicpm3.py | 2 +- python/sglang/srt/models/mixtral.py | 2 +- python/sglang/srt/models/mixtral_quant.py | 2 +- python/sglang/srt/models/mllama.py | 2 +- python/sglang/srt/models/olmo.py | 2 +- python/sglang/srt/models/olmoe.py | 2 +- python/sglang/srt/models/qwen.py | 2 +- python/sglang/srt/models/qwen2.py | 2 +- python/sglang/srt/models/qwen2_moe.py | 2 +- python/sglang/srt/models/qwen2_vl.py | 2 +- python/sglang/srt/models/registry.py | 93 ++ python/sglang/srt/models/stablelm.py | 2 +- .../sglang/srt/models/torch_native_llama.py | 2 +- python/sglang/srt/models/xverse.py | 2 +- python/sglang/srt/models/xverse_moe.py | 2 +- python/sglang/srt/models/yivl.py | 2 +- 45 files changed, 2427 insertions(+), 100 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/qqq.py create mode 100644 python/sglang/srt/model_loader/__init__.py create mode 100644 python/sglang/srt/model_loader/loader.py create mode 100644 python/sglang/srt/model_loader/utils.py create mode 100644 python/sglang/srt/model_loader/weight_utils.py create mode 100644 python/sglang/srt/models/registry.py diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 095164e1a13..f69058ff319 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -42,6 +42,7 @@ "Fp8LinearMethod", "MarlinLinearMethod", "GPTQLinearMethod", + "QQQLinearMethod", ] diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index a1bacdce036..8c7984bef6e 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -19,10 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from 
vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.qqq import QQQConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, diff --git a/python/sglang/srt/layers/quantization/qqq.py b/python/sglang/srt/layers/quantization/qqq.py new file mode 100644 index 00000000000..7930240a423 --- /dev/null +++ b/python/sglang/srt/layers/quantization/qqq.py @@ -0,0 +1,300 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/quantization/qqq.py + +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter +from torchao.ops import marlin_qqq_gemm +from torchao.quantization.utils import dynamically_quantize_per_channel +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) + +from sglang.srt.layers.quantization.base_config import QuantizationConfig + +logger = logging.getLogger(__name__) + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + + +class QQQConfig(QuantizationConfig): + """Config class for QQQ + + Reference: https://arxiv.org/pdf/2406.09904 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + is_sym: bool = True, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.is_sym = is_sym + + # Verify + if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: + raise ValueError( + f"QQQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " + "are supported." + ) + if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QQQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " + "are supported." + ) + if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: + raise ValueError( + f"QQQ does not support is_sym = {self.is_sym}. " + f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported." + ) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by QQQ kernels. + self.tile_size = MARLIN_QQQ_TILE + + # Min out_features dim + self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = MARLIN_QQQ_MAX_PARALLEL + + # Permutation length used by the QQQ kernels. 
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return "QQQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size + ) + + @classmethod + def get_name(cls) -> str: + return "qqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QQQLinearMethod"]: + if isinstance(layer, LinearBase): + return QQQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QQQLinearMethod(LinearMethodBase): + """Linear method for QQQ. + + Args: + quant_config: The QQQ quantization config. + """ + + def __init__(self, quant_config: QQQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}" + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}." + ) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}." + ) + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}." + ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." + ) + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2 + ) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError("Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. 
+ qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition + * self.quant_config.tile_size + // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader, + ) + + s_channel = ChannelQuantScaleParameter( + data=torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + weight_loader=weight_loader, + output_dim=1, + ) + + if self.quant_config.group_size == -1: + s_group_data = torch.tensor( + [], + device="cuda", + dtype=torch.half, + ) + else: + s_group_data = torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, + ) + + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} + + if self.quant_config.group_size == -1: + s_group = BasevLLMParameter(**s_group_attr) + else: + s_group = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **s_group_attr + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // self.quant_config.min_n_threads + ) * self.quant_config.max_parallel + + workspace = BasevLLMParameter( + data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), + weight_loader=weight_loader, + ) + + layer.register_parameter("B", qweight) + layer.register_parameter("s_channel", s_channel) + layer.register_parameter("s_group", s_group) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + s_ch = layer.s_channel + s_group = layer.s_group + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = s_ch.shape[1] + + x_int8, s_tok, _ = dynamically_quantize_per_channel( + x_2d, quant_min=-127, quant_max=127, target_dtype=torch.int8 + ) + # TODO(HandH1998): As the `dynamically_quantize_per_channel` function in torchao doesn't support defining the `scale_dtype`, + # we have to convert `s_tok` to `torch.float32`, which is required by `marlin_qqq_gemm`. Remove it when torchao supports defining the `scale_dtype`. 
+ s_tok = s_tok.to(torch.float32) + + output_2d = marlin_qqq_gemm( + x_int8, qweight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 9f21df7786a..839d10222e2 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -31,7 +31,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.loader import DefaultModelLoader from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -40,6 +39,7 @@ RowParallelLinear, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_loader.loader import DefaultModelLoader class BaseLayerWithLoRA(nn.Module): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7c1c51a8fb2..34660fdaca8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -19,12 +19,9 @@ import inspect import json import logging -import pkgutil -from functools import lru_cache -from typing import Optional, Type +from typing import Optional import torch -import torch.nn as nn from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( @@ -34,8 +31,6 @@ set_custom_all_reduce, ) from vllm.distributed.parallel_state import in_the_same_node_as -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import ModelRegistry from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend @@ -52,6 +47,7 @@ ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader import get_model from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -293,7 +289,6 @@ def load_model(self): load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, ) - monkey_patch_vllm_model_config() self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) if self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( @@ -318,12 +313,12 @@ def load_model(self): def update_weights(self, model_path: str, load_format: str): """Update weights in-place.""" - from vllm.model_executor.model_loader.loader import ( + from sglang.srt.model_loader.loader import ( DefaultModelLoader, device_loading_context, get_model_loader, ) - from vllm.model_executor.model_loader.utils import set_default_torch_dtype + from sglang.srt.model_loader.utils import set_default_torch_dtype logger.info( f"Update weights begin. " @@ -694,55 +689,3 @@ def model_is_mrope(self) -> bool: if rope_scaling is None: return False return rope_scaling.get("type", None) == "mrope" - - -@lru_cache() -def import_model_classes(): - model_arch_name_to_cls = {} - package_name = "sglang.srt.models" - package = importlib.import_module(package_name) - for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): - if not ispkg: - try: - module = importlib.import_module(name) - except Exception as e: - logger.warning(f"Ignore import error when loading {name}. 
{e}") - if crash_on_warnings(): - raise ValueError(f"Ignore import error when loading {name}. {e}") - continue - if hasattr(module, "EntryClass"): - entry = module.EntryClass - if isinstance( - entry, list - ): # To support multiple model classes in one module - for tmp in entry: - assert ( - tmp.__name__ not in model_arch_name_to_cls - ), f"Duplicated model implementation for {tmp.__name__}" - model_arch_name_to_cls[tmp.__name__] = tmp - else: - assert ( - entry.__name__ not in model_arch_name_to_cls - ), f"Duplicated model implementation for {entry.__name__}" - model_arch_name_to_cls[entry.__name__] = entry - - return model_arch_name_to_cls - - -def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: - model_arch_name_to_cls = import_model_classes() - - if model_arch not in model_arch_name_to_cls: - raise ValueError( - f"Unsupported architectures: {model_arch}. " - f"Supported list: {list(model_arch_name_to_cls.keys())}" - ) - return model_arch_name_to_cls[model_arch] - - -# Monkey patch model loader -setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) -setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False) -setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False) -setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False) -setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False) diff --git a/python/sglang/srt/model_loader/__init__.py b/python/sglang/srt/model_loader/__init__.py new file mode 100644 index 00000000000..df7e74143d9 --- /dev/null +++ b/python/sglang/srt/model_loader/__init__.py @@ -0,0 +1,50 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/__init__.py + +from typing import Optional + +from torch import nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, +) + +from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader +from sglang.srt.model_loader.utils import ( + get_architecture_class_name, + get_model_architecture, +) + + +def get_model( + *, + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig +) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + + +__all__ = [ + "get_model", + "get_model_loader", + "BaseModelLoader", + "get_architecture_class_name", + "get_model_architecture", +] diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py new file mode 100644 index 00000000000..f0f6f02dd29 --- /dev/null +++ b/python/sglang/srt/model_loader/loader.py @@ -0,0 +1,1255 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py + +# ruff: noqa: SIM117 +import collections +import dataclasses +import fnmatch +import glob +import json +import logging +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast + +import gguf +import huggingface_hub +import numpy as np +import torch 
+from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoadFormat, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.model_executor.models import ( + has_inner_state, + supports_lora, + supports_multimodal, +) +from vllm.platforms import current_platform +from vllm.utils import is_pin_memory_available + +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_loader.utils import ( + get_model_architecture, + set_default_torch_dtype, +) +from sglang.srt.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + get_gguf_extra_tensor_names, + get_quant_config, + gguf_quant_weights_iterator, + initialize_dummy_weights, + np_cache_weights_iterator, + pt_weights_iterator, + safetensors_weights_iterator, +) +from sglang.srt.utils import set_weight_attrs + + +@contextmanager +def device_loading_context(module: torch.nn.Module, target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = logging.getLogger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig +) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." 
+ ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}" + ) + return quant_config + return None + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig] = None, +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs: Dict[str, Any] = {} + + if supports_lora(model_class): + # lora_config=None is used to disable LoRA + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github." + ) + + if supports_multimodal(model_class): + assert multimodal_config is not None + + extra_kwargs["multimodal_config"] = multimodal_config + + if has_inner_state(model_class) and scheduler_config: + extra_kwargs["scheduler_config"] = scheduler_config + + return extra_kwargs + + +def build_model( + model_class: Type[nn.Module], + hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + *, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig], +) -> nn.Module: + extra_kwargs = _get_model_initialization_kwargs( + model_class, lora_config, multimodal_config, scheduler_config + ) + + return model_class( + config=hf_config, + cache_config=cache_config, + quant_config=quant_config, + **extra_kwargs, + ) + + +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig, + scheduler_config: Optional[SchedulerConfig] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class, _ = get_model_architecture(model_config) + + return build_model( + model_class, + model_config.hf_config, + cache_config=cache_config, + quant_config=_get_quantization_config(model_config, load_config), + lora_config=lora_config, + multimodal_config=model_config.multimodal_config, + scheduler_config=scheduler_config, + ) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + 
"""Whether .pt weights can be used.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str] + ) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = ( + self._maybe_download_from_modelscope(model_name_or_path, revision) + or model_name_or_path + ) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt + ) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + ) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + # Apply the prefix. + return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ()) + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights( + model_config.model, model_config.revision, fall_back_to_pt=True + ) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model( + model_config, + self.load_config, + lora_config, + cache_config, + scheduler_config, + ) + + model.load_weights(self._get_all_weights(model_config, model)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
+ with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, + self.load_config, + lora_config, + cache_config, + ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + return model.eval() + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_state.py` for creating a sharded checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ( + {} + if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy() + ) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}" + ) + + @staticmethod + def _filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list) + ) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. 
+ break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + from safetensors.torch import safe_open + from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights( + model_config.model, model_config.revision + ) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, cache_config + ) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!" + ) + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + default_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ".fc1.", + ".fc2.", + ".dense.", + ".query_key_value.", + ".qkv_proj.", + ".dense_h_to_4h.", + ".dense_4h_to_h.", + ".out_proj.", + ] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. 
+ if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + self.target_modules = [] + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download( + repo_id=qlora_adapter, filename=file + ) + break + + if not config_file_path: + raise ValueError(f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None, + ) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob(os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError(f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision + ) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.44.0": + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.44.0." 
+ ) + except ImportError as err: + raise ImportError( + "Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " + "bitsandbytes quantizer." + ) from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision + ) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return ( + self._quantized_8bit_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + else: + return ( + self._quantized_4bit_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + + return ( + self._unquantized_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), + quant_state_dict, + ) + + def _quantized_8bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if not weight_name.endswith((".weight", ".bias")): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith((".weight", ".bias")): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if not weight_name.endswith((".weight", ".bias")): + continue + + if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or ( + f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict + ): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: + from bitsandbytes.functional import quantize_4bit + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors + ): + + if any( + target_module in weight_name for target_module in self.target_modules + ) and weight_name.endswith(".weight"): + weight_name = weight_name.replace(".weight", ".qweight") + + if any( + module in weight_name + for module in self.column_parallel_weights_modules + ): + + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, ...] + + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, compress_statistics=True, quant_type="nf4" + ) + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}." + ) + + if not hasattr(model, "bitsandbytes_stacked_params_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet." + ) + + if len(self.target_modules) == 0: + if hasattr(model, "default_bitsandbytes_target_modules"): + self.target_modules = model.default_bitsandbytes_target_modules + else: + self.target_modules = self.default_target_modules + + if hasattr(model, "column_parallel_weights_modules"): + self.column_parallel_weights_modules = model.column_parallel_weights_modules + else: + self.column_parallel_weights_modules = [] + + self.model_type = type(model).__name__ + + logger.info( + "Loading weights with BitsAndBytes quantization. " " May take a while ..." 
+ ) + + quant_config = getattr(model_config.hf_config, "quantization_config", None) + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get("quant_method") + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization" + ) + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." + "Please try with PP." + ) + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get("load_in_8bit", False) + + qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator( + model_config.model, model_config.revision, pre_quant, load_8bit + ) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, + index, + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace(shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model." + ) + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[ + non_stacked_param_name + ] + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError(f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod(quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)} + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, cache_config + ) + + self._load_weights(model_config, model) + + return model.eval() + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. 
+ """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config(config) + state_dict = dummy_model.state_dict() + + gguf_to_hf_name_map = {} + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map + ): + model_config.hf_config.update({"tie_word_embeddings": True}) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, cache_config + ) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map) + ) + return model + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + return DefaultModelLoader(load_config) 
diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py new file mode 100644 index 00000000000..3e603131ea2 --- /dev/null +++ b/python/sglang/srt/model_loader/utils.py @@ -0,0 +1,40 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/utils.py + +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn +from vllm.config import ModelConfig + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + from sglang.srt.models.registry import ModelRegistry + + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"] + + if ( + model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures + ): + architectures = ["QuantMixtralForCausalLM"] + + return ModelRegistry.resolve_model_cls(architectures) + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py new file mode 100644 index 00000000000..50097ac6506 --- /dev/null +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -0,0 +1,646 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/weight_utils.py + +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import logging +import os +import tempfile +from collections import defaultdict +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import filelock +import gguf +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm +from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config + +logger = logging.getLogger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. 
+# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer""" + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """ + ) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: + + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. 
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError(f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}" + ) + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}." + ) + + return quant_cls.from_config(config) + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. 
+ """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=index_file, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have index_file. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", index_file) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: List[str], hf_folder: str, index_file: str +) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name, "r") as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. 
+ + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def np_cache_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names: List[str] = [] + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: 
Dict[str, str] +) -> List[str]: + reader = gguf.GGUFReader(gguf_file) + expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})" + ) + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. 
+ raise + + +def row_parallel_weight_loader( + param: torch.Tensor, loaded_weight: torch.Tensor +) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] +) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + if current_platform.is_tpu(): + # XLA device does not support torch.Generator() + param.uniform_(low, high) + continue + + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. 
+ """ + if name.endswith(".kv_scale"): + print_warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded." + ) + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded." + ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 0e5e3b9ade8..40b06b6ee00 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -34,7 +34,6 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -46,6 +45,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 05ce17a6b10..30fcc2cb958 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -23,7 +23,6 @@ from torch.nn import LayerNorm from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul @@ -41,6 +40,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader LoraConfig = None diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index d4018be88a1..8fca8bf3f7e 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -49,7 +49,6 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -62,6 +61,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from 
sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import set_weight_attrs diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b8dad0248aa..20101309ba7 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -25,7 +25,6 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.fused_moe_triton import fused_moe @@ -43,6 +42,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import set_weight_attrs diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index cdebafa2ff6..f0a80a977d5 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -27,7 +27,6 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import fused_moe @@ -46,6 +45,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class DeepseekMLP(nn.Module): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 85467c12c90..c6b0e4a959a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -28,7 +28,6 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import FusedMoE @@ -48,6 +47,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index c097e00ad28..c434c0d2dbb 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -22,7 +22,6 @@ from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -39,6 +38,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class ExaoneGatedMLP(nn.Module): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index a53fad95803..2dd140cda69 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -24,7 +24,6 @@ from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from 
vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -38,6 +37,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GemmaMLP(nn.Module): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 0fa6a539352..9affb01a042 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -23,9 +23,6 @@ from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( @@ -38,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 8d988fe8ea8..3d8e2edc73d 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -26,7 +26,6 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader # from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( @@ -39,6 +38,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GPT2Attention(nn.Module): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 03597fa7343..cb852aa9b17 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -23,7 +23,6 @@ from transformers import GPTBigCodeConfig from vllm.config import LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( @@ -36,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class GPTBigCodeAttention(nn.Module): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index f8326c72ddb..4807f74b0af 100644 --- a/python/sglang/srt/models/grok.py +++ 
b/python/sglang/srt/models/grok.py @@ -28,8 +28,6 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.loader import DefaultModelLoader -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_grok import FusedMoE from sglang.srt.layers.layernorm import RMSNorm @@ -46,6 +44,8 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.loader import DefaultModelLoader +from sglang.srt.model_loader.weight_utils import default_weight_loader class Grok1MoE(nn.Module): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 59ff6d1e2dd..a222c7585ef 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -21,7 +21,6 @@ from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -38,6 +37,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class InternLM2MLP(nn.Module): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 7e9fd0f7267..aae14ac13eb 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -23,7 +23,6 @@ from transformers import LlamaConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -43,6 +42,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index c22b68d11e5..302d697a4f2 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -17,11 +17,11 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index da43d03fcaa..4278e21e8a8 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -3,10 +3,10 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from 
sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.model_executor.model_runner import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaModel diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 5eb2daae637..7322e45a917 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -21,6 +21,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 780bf36b5d9..61174d943b9 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -29,7 +29,6 @@ SiglipVisionModel, ) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.schedule_batch import ImageInputs @@ -39,6 +38,7 @@ unpad_image_shape, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index c06ef876954..6a2c4ea60cf 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -20,11 +20,11 @@ from torch import nn from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 239cfb6fcc6..40f67953c71 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -20,7 +20,6 @@ from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -37,6 +36,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MiniCPMMLP(nn.Module): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 6f53f2974fb..5bf84b95963 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -27,7 +27,6 @@ 
RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -40,6 +39,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 98d5ab332a0..1e2eecf8c6c 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -23,7 +23,6 @@ from transformers import MixtralConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm @@ -42,6 +41,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MixtralMoE(nn.Module): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index d15a389a841..373ea2a3af8 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -29,7 +29,6 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -45,6 +44,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class MixtralMLP(nn.Module): diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 63bbfdb7ebe..a80cd880894 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -15,7 +15,6 @@ _prepare_aspect_ratio_attention_mask, ) from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.layernorm import RMSNorm @@ -34,6 +33,7 @@ ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 80fd64a53a7..05afc43e1a6 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -22,7 +22,6 @@ from transformers import OlmoConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -38,6 +37,7 @@ VocabParallelEmbedding, ) 
from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 407eb98cb3e..ec0bb86fb90 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -34,7 +34,6 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.utils import print_warning_once from sglang.srt.layers.activation import SiluAndMul @@ -48,6 +47,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 4c182902657..8adb965c762 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -22,7 +22,6 @@ from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -39,6 +38,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class QWenMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 634ce1cf166..6d15ba2151f 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -22,7 +22,6 @@ from torch import nn from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -40,6 +39,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers Qwen2Config = None diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index febd6d74843..39edb53c0c9 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -27,7 +27,6 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.fused_moe_triton import FusedMoE @@ -48,6 +47,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class Qwen2MoeMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 3d387624369..54aa8d5ff17 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -35,7 +35,6 @@ from vllm.distributed import utils as dist_utils from vllm.logger 
import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor @@ -49,6 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model logger = init_logger(__name__) diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py new file mode 100644 index 00000000000..35e5dc3f0d6 --- /dev/null +++ b/python/sglang/srt/models/registry.py @@ -0,0 +1,93 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/models/registry.py + +import importlib +import logging +import pkgutil +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict) + + def get_supported_archs(self) -> List[str]: + return list(self.models.keys()) + + def _raise_for_unsupported(self, architectures: List[str]): + all_supported_archs = self.get_supported_archs() + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {all_supported_archs}" + ) + + def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in self.models: + return None + + return self.models[model_arch] + + def _normalize_archs( + self, + architectures: Union[str, List[str]], + ) -> List[str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return architectures + + def resolve_model_cls( + self, + architectures: Union[str, List[str]], + ) -> Tuple[Type[nn.Module], str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + return self._raise_for_unsupported(architectures) + + +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + package_name = "sglang.srt.models" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + try: + module = importlib.import_module(name) + except Exception as e: + logger.warning(f"Ignore import error when loading {name}. 
" f"{e}") + continue + if hasattr(module, "EntryClass"): + entry = module.EntryClass + if isinstance( + entry, list + ): # To support multiple model classes in one module + for tmp in entry: + assert ( + tmp.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {tmp.__name__}" + model_arch_name_to_cls[tmp.__name__] = tmp + else: + assert ( + entry.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {entry.__name__}" + model_arch_name_to_cls[entry.__name__] = entry + + return model_arch_name_to_cls + + +ModelRegistry = _ModelRegistry(import_model_classes()) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9fa2ab34330..ffa40cf6d39 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -26,7 +26,6 @@ from transformers import PretrainedConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( @@ -42,6 +41,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class StablelmMLP(nn.Module): diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index b9451d59152..37ed78709e1 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -52,7 +52,6 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -66,6 +65,7 @@ ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index fb7e14a0efd..8d5e932fa93 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -30,7 +30,6 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -40,6 +39,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.model_runner import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class XverseMLP(nn.Module): diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index c6458f7f503..e3ec713392e 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -32,7 +32,6 @@ RowParallelLinear, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.logits_processor import 
LogitsProcessor @@ -43,6 +42,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader class XverseMLP(nn.Module): diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 6f1610e5254..bfb9a2c1c2b 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -18,9 +18,9 @@ import torch import torch.nn as nn from transformers import CLIPVisionModel, LlavaConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llava import LlavaLlamaForCausalLM
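# --- illustrative sketch (annotation, not part of the patch) ---
# How a model file plugs into the new sglang.srt.models.registry added above:
# import_model_classes() walks the sglang.srt.models package and registers
# every module-level EntryClass, so a model module under sglang/srt/models/
# only needs to expose that attribute. The class name below is hypothetical,
# purely for illustration.
import torch.nn as nn


class MyCausalLM(nn.Module):
    def forward(self, *args, **kwargs):
        raise NotImplementedError


# A single class, or a list when one module implements several architectures.
EntryClass = MyCausalLM

# At load time, get_model_architecture() in model_loader/utils.py resolves the
# names in hf_config.architectures through the registry, e.g. (assuming the
# hypothetical module above were actually part of the package):
#
#     from sglang.srt.models.registry import ModelRegistry
#     model_cls, arch = ModelRegistry.resolve_model_cls(["MyCausalLM"])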