diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 6d144f84433..cc04bb4d420 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -98,6 +98,7 @@ def __init__(
         if (
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
new file mode 100644
index 00000000000..9588eb87e03
--- /dev/null
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -0,0 +1,295 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only DeepSeek NextN Speculative Decoding."""
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
+
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import ReplicatedLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
+
+
+class DeepseekModelNextN(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
+        )
+
+        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+
+        self.decoder = DeepseekV2DecoderLayer(
+            config, 0, quant_config=quant_config, is_nextn=True
+        )
+
+        self.shared_head = nn.Module()
+        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.eh_proj(
+            torch.cat(
+                (
+                    self.enorm(hidden_states),
+                    self.hnorm(forward_batch.spec_info.hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+
+        residual = None
+        hidden_states, residual = self.decoder(
+            positions, hidden_states, forward_batch, residual
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = DeepseekModelNextN(config, quant_config)
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.model.shared_head.head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.model.shared_head.head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+            self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.shared_head.head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        if hasattr(self.config, "num_nextn_predict_layers"):
+            num_nextn_layers = self.config.num_nextn_predict_layers
+            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+            assert num_nextn_layers == self.config.num_hidden_layers
+        else:
+            raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts,
+        )
+
+        nextn_layer_prefix = "model.layers.0"
+        nextn_spec_weight_names = [
+            "shared_head.head",
+            "shared_head.norm",
+            "eh_proj",
+            "embed_tokens",
+            "enorm",
+            "hnorm",
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if not name.startswith(nextn_layer_prefix):
+                continue
+            else:
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+        if not global_server_args_dict["disable_mla"]:
+            self_attn = self.model.decoder.self_attn
+            if hasattr(self_attn.kv_b_proj, "qweight"):
+                # AWQ compatible
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
+            else:
+                w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
+            ):
+                weight_block_size = self.quant_config.weight_block_size
+                if weight_block_size is not None:
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if is_hip_:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                    w, scale = block_quant_to_tensor_quant(
+                        weight, weight_scale, weight_block_size
+                    )
+                    self_attn.w_scale = scale
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+            if (
+                hasattr(self_attn.kv_b_proj, "weight_scale")
+                and self_attn.w_scale is None
+            ):
+                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                if is_hip_:
+                    self_attn.w_scale *= 2.0
+
+
+EntryClass = [DeepseekV3ForCausalLMNextN]
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 9046f227da9..df4f9ed1430 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -519,6 +519,8 @@ def forward(
         # Triton: Use normal computation for prefill and use weight absorption for extend/decode
         if (
             forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
             and forward_batch.extend_prefix_lens.sum() == 0
         ):
             return self.forward_normal(positions, hidden_states, forward_batch)
@@ -680,6 +682,7 @@ def __init__(
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        is_nextn: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -731,7 +734,7 @@ def __init__(
             quant_config=quant_config,
             layer_id=layer_id,
         )
-        if (
+        if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a8ab27cc910..b4c6a1224fb 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -262,14 +262,17 @@ def __post_init__(self):
             )
 
         # Speculative Decoding
-        if self.speculative_algorithm == "EAGLE":
+        if (
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "NEXTN"
+        ):
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
             self.disable_radix_cache = True
             self.disable_overlap_schedule = True
             self.chunked_prefill_size = -1
             logger.info(
-                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
             )
 
         # GGUF
@@ -705,7 +708,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE"],
+            choices=["EAGLE", "NEXTN"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 8c24f2aa561..eb8e839f950 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -24,6 +24,7 @@
     fast_topk,
     select_top_k_tokens,
 )
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 logger = logging.getLogger(__name__)
 
@@ -57,11 +58,15 @@ def __init__(
         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
         self.server_args = server_args
 
         # Share the embedding and lm_head
-        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-        self.model_runner.model.set_embed_and_head(embed, head)
+        if not self.speculative_algorithm.is_nextn():
+            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+            self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
 
         # Create multi-step attn backends and cuda graph runners
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index 5f156b837f9..3d352001881 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
 
+    # NEXTN spec decoding is for DeepSeek V3/R1
+    # currently it's implemented based on EAGLE
+    NEXTN = auto()
+
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
 
     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
+        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
+
+    def is_nextn(self):
+        return self == SpeculativeAlgorithm.NEXTN
 
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
+        if name is not None:
+            name = name.upper()
         return name_map[name]
diff --git a/scripts/export_deepseek_nextn.py b/scripts/export_deepseek_nextn.py
new file mode 100644
index 00000000000..3e72fee4f67
--- /dev/null
+++ b/scripts/export_deepseek_nextn.py
@@ -0,0 +1,113 @@
+"""
+Export the NextN layer of a DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
+
+Usage:
+python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
+"""
+
+import argparse
+import json
+import os
+import shutil
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+from transformers import AutoConfig
+
+
+def get_nextn_layer_id(config):
+    if not hasattr(config, "num_hidden_layers"):
+        raise ValueError("'num_hidden_layers' not found in model config.")
+    return config.num_hidden_layers
+
+
+def update_and_save_config(config, output_dir):
+    new_config = config.to_dict()
+    new_config.update(
+        {
+            "num_hidden_layers": 1,
+            "architectures": ["DeepseekV3ForCausalLMNextN"],
+        }
+    )
+    with open(os.path.join(output_dir, "config.json"), "w") as f:
+        json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
+
+
+def copy_non_safetensors_files(input_dir, output_dir):
+    for filename in os.listdir(input_dir):
+        src_file_path = os.path.join(input_dir, filename)
+        if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
+            dst_file_path = os.path.join(output_dir, filename)
+            shutil.copy2(src_file_path, dst_file_path)
+    print(f"All non-safetensors files have been copied to {output_dir}")
+
+
+def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
+    prefix = f"model.layers.{nextn_layer_id}"
+    output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
+    params = {}
+    for filename in os.listdir(input_dir):
+        if not filename.endswith(".safetensors"):
+            continue
+
+        file_path = os.path.join(input_dir, filename)
+        print(f"Processing: {filename}")
+
+        try:
+            with safe_open(file_path, framework="pt") as f:
+                matching_keys = [k for k in f.keys() if k.startswith(prefix)]
+
+                if not matching_keys:
+                    print(f"  No parameters starting with '{prefix}' found")
+                    continue
+
+                for key in matching_keys:
+                    new_key = key.replace(prefix, "model.layers.0")
+                    params[new_key] = f.get_tensor(key)
+
+        except Exception as e:
+            print(f"  Error processing {filename}: {str(e)}")
+
+    if params:
+        print(f"Saving {len(params)} parameters to {output_path}")
+        save_file(params, output_path)
+    else:
+        print("No matching parameters found.")
+
+    # Update safetensors index
+    index_path = os.path.join(output_dir, "model.safetensors.index.json")
+    print(f"Updating safetensors index to {index_path}")
+    index_data = {"weight_map": {}}
+    for key in params:
+        index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
+    with open(index_path, "w") as f:
+        json.dump(index_data, f, indent=4)
+
+    print("All done.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Export NextN layer parameters for DeepSeek-V3/R1"
+    )
+    parser.add_argument(
+        "--input-dir",
+        type=str,
+        required=True,
+        help="Input HF model directory.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        required=True,
+        help="Output nextn model directory.",
+    )
+    args = parser.parse_args()
+
+    config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
+    assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
+    nextn_layer_id = get_nextn_layer_id(config)
+    os.makedirs(args.output_dir, exist_ok=True)
+    copy_non_safetensors_files(args.input_dir, args.output_dir)
+    update_and_save_config(config, args.output_dir)
+    export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)
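
Note on weight naming: checkpoint keys are renamed twice before they reach module parameters. `scripts/export_deepseek_nextn.py` moves the checkpoint's single NextN layer, `model.layers.<num_hidden_layers>.*`, to `model.layers.0.*`; `DeepseekV3ForCausalLMNextN.load_weights` then maps that prefix to `model.*` for the NextN-specific tensors listed in `nextn_spec_weight_names`, and to `model.decoder.*` for ordinary decoder-layer tensors. Below is a minimal, dependency-free sketch of that routing; the sample keys are illustrative, and the real loader additionally handles stacked and expert parameters.

```python
# Standalone sketch of the two-stage key renaming in this PR.
# The sample keys are illustrative, not an exhaustive list.
NEXTN_SPEC_WEIGHT_NAMES = [
    "shared_head.head",
    "shared_head.norm",
    "eh_proj",
    "embed_tokens",
    "enorm",
    "hnorm",
]


def export_rename(key: str, nextn_layer_id: int) -> str:
    # Mirrors export_deepseek_nextn.py: relocate the NextN layer to index 0.
    return key.replace(f"model.layers.{nextn_layer_id}", "model.layers.0")


def load_rename(key: str) -> str:
    # Mirrors load_weights: NextN-specific tensors live directly on the
    # model ("model.*"); everything else belongs to the single decoder
    # layer ("model.decoder.*").
    prefix = "model.layers.0"
    if any(w in key for w in NEXTN_SPEC_WEIGHT_NAMES):
        return key.replace(prefix, "model")
    return key.replace(prefix, "model.decoder")


if __name__ == "__main__":
    # DeepSeek-V3 has 61 decoder layers (0-60); its MTP layer is stored as 61.
    for ckpt_key in (
        "model.layers.61.enorm.weight",
        "model.layers.61.self_attn.kv_b_proj.weight",
        "model.layers.61.shared_head.head.weight",
    ):
        exported = export_rename(ckpt_key, nextn_layer_id=61)
        print(f"{ckpt_key} -> {exported} -> {load_rename(exported)}")
```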
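On the serving side, NEXTN rides on the existing EAGLE machinery: `SpeculativeAlgorithm.is_eagle()` returns True for NEXTN, so the EAGLE code paths apply unchanged, and `eagle_worker.py` only skips the embedding/lm_head sharing because `DeepseekModelNextN` carries its own `embed_tokens` and `shared_head`. A launch would presumably pair `--speculative-algorithm NEXTN` with the pre-existing draft-model flag pointing at the exported directory, e.g. `--speculative-draft-model-path /path/to/DeepSeek-V3-NextN`; that flag is not part of this diff, so its exact name should be checked against `server_args.py`.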