Skip to content

Commit

Permalink
support speculative decoding kernel in sgl-kernel (#3373)
Browse files Browse the repository at this point in the history
Co-authored-by: Ying Sheng <[email protected]>
  • Loading branch information
zhyncs and Ying1123 authored Feb 7, 2025
1 parent 45c87e0 commit f9905d5
Show file tree
Hide file tree
Showing 13 changed files with 1,299 additions and 133 deletions.
584 changes: 482 additions & 102 deletions python/sglang/srt/speculative/build_eagle_tree.py

Large diffs are not rendered by default.

96 changes: 67 additions & 29 deletions python/sglang/srt/speculative/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,39 +258,77 @@ def generate_attn_arg_prefill(
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
)
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1,
)
target_predict = predict[self.retrive_index]
candidates = draft_token[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
bs = self.retrive_cum_len.numel() - 1

max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
if batch.sampling_info.is_all_greedy:
# temp == 0
bs = self.retrive_cum_len.numel() - 1
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
)
target_predict = predict[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]

accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
else:
# temp > 0
bs = self.retrive_index.shape[0]
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
target_logits = logits_output.next_token_logits[self.retrive_index]
predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
accept_index = torch.full(
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(
target_probs, 0, dtype=torch.float32, device="cuda"
)
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
tree_speculative_sampling_target_only(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32),
retrive_index=self.retrive_index.to(torch.int32),
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
uniform_samples=coins,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)

new_accept_index = []
unfinished_index = []
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.3.post1"
version = "0.0.3.post2"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
2 changes: 2 additions & 0 deletions sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _get_version():
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
"src/sgl-kernel/csrc/eagle_utils.cu",
"src/sgl-kernel/csrc/speculative_sampling.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/norm.cu",
Expand Down
6 changes: 6 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
build_tree_kernel,
build_tree_kernel_efficient,
custom_dispose,
custom_reduce,
fp8_scaled_mm,
Expand All @@ -31,6 +33,7 @@
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
)

__all__ = [
Expand All @@ -57,4 +60,7 @@
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
"tree_speculative_sampling_target_only",
"build_tree_kernel_efficient",
"build_tree_kernel",
]
209 changes: 209 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright (c) 2025 by SGLang team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len,
bool* tree_mask, int64_t* positions, int64_t* retrive_index,
int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;

if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}

int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;

int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped.");
continue;
}

if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}

int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}

void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
int64_t depth, int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()), static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk), int32_t(depth), int32_t(draft_token_num));
}

// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask,
int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;

if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}

int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
return;
}

int depends_order[10];

int cur_position = tid - 1;
while (true) {
depends_order[position] = cur_position + 1;
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}

int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
if (cur_position == draft_token_num) {
printf(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped.");
break;
}
}
positions[bid * draft_token_num + tid] = position + seq_len;

int is_leaf = 0;
for (int i = 1; i < draft_token_num; i++) {
if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
is_leaf++;
}
}
if (is_leaf == 1) {
for (int i = 0; i < position; i++) {
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
depends_order[i] + bid * draft_token_num;
}
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
}
}

void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
int64_t depth, int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

build_tree<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()), int32_t(topk),
int32_t(depth), int32_t(draft_token_num));
}
Loading

0 comments on commit f9905d5

Please sign in to comment.