diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index e0ac9fe0bb6..87896cb6c2a 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -1,124 +1,175 @@ -import cutex +# NOTE: Please run this file to make sure the test cases are correct. + +from typing import List + import torch -# parent_table [bs,topk*depth+)] -# selected_index [bs,draft_token_num-1)] -# verified_seq_len [bs] -# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] -# positions [bs*draft_token] -# retrive_index [b, draft_token, depth+2] -kernels = cutex.SourceModule( - """ -//cuda -__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, - Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { - int bid = blockIdx.x; - int tid = threadIdx.x; - if (tid >= draft_token_num){ - return; - } - int seq_tree_idx = draft_token_num * draft_token_num * bid; - for(int i=0; i torch.Tensor: - predict = torch.argmax(logits_output.next_token_logits, dim=-1) - predict = torch.cat( - [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 - ) draft_token = torch.cat( - [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], + [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1, ) - target_predict = predict[self.retrive_index] candidates = draft_token[self.retrive_index] - # logits = logits_output.next_token_logits[self.retrive_index] - # target_predict = torch.argmax(logits[:, :-1], dim=-1) - accept_mask = candidates[:, 1:] == target_predict[:, :-1] - accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) - bs = self.retrive_cum_len.numel() - 1 - - max_draft_len = self.retrive_index.shape[-1] - accept_index = torch.full( - (bs, max_draft_len), -1, dtype=torch.long, device="cuda" - ) - accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") - extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") - eagle_verify_retrive[(bs,)]( - self.retrive_index.contiguous(), - accept_mask.contiguous(), - self.retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_draft_len, - self.draft_token_num, - triton.next_power_of_2(max_draft_len), - ) + if batch.sampling_info.is_all_greedy: + # temp == 0 + bs = self.retrive_cum_len.numel() - 1 + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat( + [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1 + ) + target_predict = predict[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.int32, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + else: + # temp > 0 + bs = self.retrive_index.shape[0] + predict_shape = list(logits_output.next_token_logits.shape)[:-1] + predict_shape[-1] += 1 + target_logits = logits_output.next_token_logits[self.retrive_index] + predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda") + accept_index = torch.full( + (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") + expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1) + target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) + draft_probs = torch.full_like( + target_probs, 0, dtype=torch.float32, device="cuda" + ) + coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") + tree_speculative_sampling_target_only( + predicts=predict, # mutable + accept_index=accept_index, # mutable + accept_token_num=accept_length, # mutable + candidates=candidates.to(torch.int32), + retrive_index=self.retrive_index.to(torch.int32), + retrive_next_token=self.retrive_next_token.to(torch.int32), + retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), + uniform_samples=coins, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict[ + "speculative_accept_threshold_acc" + ], + deterministic=True, + ) new_accept_index = [] unfinished_index = [] diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index bb7d6943348..85d926dff46 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3.post1" +version = "0.0.3.post2" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 9a93ae99229..8b2f3775180 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -99,6 +99,8 @@ def _get_version(): "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu", + "src/sgl-kernel/csrc/eagle_utils.cu", + "src/sgl-kernel/csrc/speculative_sampling.cu", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index ff41db8e43f..13bae677a09 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -10,6 +10,8 @@ from sgl_kernel.ops import ( apply_rope_with_cos_sin_cache_inplace, bmm_fp8, + build_tree_kernel, + build_tree_kernel_efficient, custom_dispose, custom_reduce, fp8_scaled_mm, @@ -31,6 +33,7 @@ top_k_renorm_prob, top_k_top_p_sampling_from_probs, top_p_renorm_prob, + tree_speculative_sampling_target_only, ) __all__ = [ @@ -57,4 +60,7 @@ "top_k_renorm_prob", "top_k_top_p_sampling_from_probs", "top_p_renorm_prob", + "tree_speculative_sampling_target_only", + "build_tree_kernel_efficient", + "build_tree_kernel", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu b/sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu new file mode 100644 index 00000000000..af44261cc18 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2025 by SGLang team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = +// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, +// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token] +__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, + bool* tree_mask, int64_t* positions, int64_t* retrive_index, + int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth, + int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for (int i = 0; i < bid; i++) { + seq_tree_idx += verified_seq_len[i] * draft_token_num; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + for (int i = 0; i < draft_token_num - 1; i++) { + tree_mask[token_tree_idx + i] = false; + } + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token " + "will be dropped."); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + tree_mask[token_tree_idx + cur_position] = true; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + +void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, + at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, + at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, + int64_t depth, int64_t draft_token_num) { + // TODO (ying) check shape + // TODO (ying) check type + int bs = parent_list.size(0); + dim3 grid(bs); + dim3 block(draft_token_num); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + build_tree_efficient<<>>( + static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), int32_t(depth), int32_t(draft_token_num)); +} + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = +// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, +// draft_token, depth + 2] +__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask, + int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for (int i = 0; i < bid; i++) { + seq_tree_idx += verified_seq_len[i] * draft_token_num; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + for (int i = 0; i < draft_token_num - 1; i++) { + tree_mask[token_tree_idx + i] = false; + } + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num; + return; + } + + int depends_order[10]; + + int cur_position = tid - 1; + while (true) { + depends_order[position] = cur_position + 1; + position += 1; + tree_mask[token_tree_idx + cur_position] = true; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; cur_position++) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + if (cur_position == draft_token_num) { + printf( + "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token " + "will be dropped."); + break; + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + + int is_leaf = 0; + for (int i = 1; i < draft_token_num; i++) { + if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) { + is_leaf++; + } + } + if (is_leaf == 1) { + for (int i = 0; i < position; i++) { + retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] = + depends_order[i] + bid * draft_token_num; + } + retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num; + } +} + +void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, + at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, + int64_t depth, int64_t draft_token_num) { + // TODO (ying) check shape + // TODO (ying) check type + int bs = parent_list.size(0); + dim3 grid(bs); + dim3 block(draft_token_num); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + build_tree<<>>( + static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), int32_t(topk), + int32_t(depth), int32_t(draft_token_num)); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu new file mode 100644 index 00000000000..a249455101b --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2025 by SGLang team. + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +// predicts: [tot_num_draft_tokens] +// accept_index: [bs, num_spec_step] +// accept_token_num: [bs] +// candidates: [bs, num_draft_tokens] +// retrive_index: [bs, num_draft_tokens] +// retrive_next_token: [bs, num_draft_tokens] +// retrive_next_sibling: [bs, num_draft_tokens] +// uniform_samples: [bs, num_draft_tokens] +// target_probs: [bs, num_draft_tokens, vocab_size] +void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, at::Tensor retrive_index, + at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, + at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, + bool deterministic, int64_t cuda_stream = 0) { + CHECK_INPUT(candidates); + CHECK_INPUT(retrive_index); + CHECK_INPUT(retrive_next_token); + CHECK_INPUT(retrive_next_sibling); + CHECK_INPUT(uniform_samples); + CHECK_INPUT(target_probs); + auto device = target_probs.device(); + CHECK_EQ(candidates.device(), device); + CHECK_EQ(retrive_index.device(), device); + CHECK_EQ(retrive_next_token.device(), device); + CHECK_EQ(retrive_next_sibling.device(), device); + CHECK_EQ(uniform_samples.device(), device); + CHECK_EQ(target_probs.device(), device); + CHECK_DIM(1, predicts); + CHECK_DIM(2, accept_index); + CHECK_DIM(1, accept_token_num); + CHECK_DIM(2, candidates); + CHECK_DIM(2, retrive_index); + CHECK_DIM(2, retrive_next_token); + CHECK_DIM(2, retrive_next_sibling); + CHECK_DIM(2, uniform_samples); + CHECK_DIM(3, target_probs); + CHECK_DIM(3, draft_probs); + unsigned int batch_size = uniform_samples.size(0); + unsigned int num_spec_step = accept_index.size(1); + unsigned int num_draft_tokens = candidates.size(1); + unsigned int vocab_size = target_probs.size(2); + CHECK_EQ(batch_size, candidates.size(0)); + CHECK_EQ(batch_size, retrive_index.size(0)); + CHECK_EQ(batch_size, retrive_next_token.size(0)); + CHECK_EQ(batch_size, retrive_next_sibling.size(0)); + CHECK_EQ(batch_size, target_probs.size(0)); + CHECK_EQ(num_draft_tokens, retrive_index.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_token.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1)); + CHECK_EQ(num_draft_tokens, uniform_samples.size(1)); + CHECK_EQ(num_draft_tokens, target_probs.size(1)); + CHECK_EQ(vocab_size, target_probs.size(2)); + CHECK_EQ(batch_size, accept_index.size(0)); + CHECK_EQ(batch_size, accept_token_num.size(0)); + if (predicts.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32)."); + } + if (accept_index.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32)."); + } + if (accept_token_num.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32)."); + } + if (candidates.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32)."); + } + if (retrive_index.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32)."); + } + if (retrive_next_token.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32)."); + } + if (retrive_next_sibling.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32)."); + } + if (uniform_samples.scalar_type() != at::kFloat) { + throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32)."); + } + if (target_probs.scalar_type() != at::kFloat) { + throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32)."); + } + if (draft_probs.scalar_type() != at::kFloat) { + throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32)."); + } + + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( + static_cast(predicts.data_ptr()), static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), static_cast(uniform_samples.data_ptr()), + static_cast(target_probs.data_ptr()), static_cast(draft_probs.data_ptr()), batch_size, + num_spec_step, num_draft_tokens, vocab_size, deterministic, stream); + + TORCH_CHECK(status == cudaSuccess, + "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status))); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh b/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh new file mode 100644 index 00000000000..b9a32d2a90e --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2025 by SGLang team. + * Copyright (c) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SPECULATIVE_SAMPLING_CUH_ +#define SPECULATIVE_SAMPLING_CUH_ + +#include + +#include + +namespace flashinfer { + +namespace sampling { + +using namespace cub; + +template +__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index, + IdType* accept_token_num, // mutable + IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, + IdType* retrive_next_sibling, DType* uniform_samples, + DType* target_probs, DType* draft_probs, uint32_t batch_size, + uint32_t num_speculative_tokens, uint32_t num_draft_tokens, + uint32_t d) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + extern __shared__ __align__(alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>(smem_sampling); + + DType prob_acc = 0.0; + uint32_t cur_prob_offset = bx * num_draft_tokens * d; + DType coin = uniform_samples[bx * num_draft_tokens]; + IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; + uint32_t num_accepted_tokens = 0; + IdType cur_index = 0; + + for (uint32_t j = 1; j < num_speculative_tokens; ++j) { + cur_index = retrive_next_token[bx * num_draft_tokens + cur_index]; + while (cur_index != -1) { + IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index]; + IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index]; + prob_acc += target_probs[cur_prob_offset + draft_token_id]; + + if (coin < prob_acc) { + // accept token + prob_acc = 0.; + cur_prob_offset = (bx * num_draft_tokens + cur_index) * d; + coin = uniform_samples[bx * num_draft_tokens + cur_index]; + predicts[last_accepted_retrive_idx] = draft_token_id; + ++num_accepted_tokens; + accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index; + last_accepted_retrive_idx = draft_index; + break; + } else { + // FIXME: leverage draft probs + draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id]; + cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index]; + } + } + if (cur_index == -1) break; + } + accept_token_num[bx] = num_accepted_tokens; + + // sample from relu(target_probs - draft_probs) + DType sum_relu_q_minus_p(0); + vec_t q_vec, p_vec; + DType relu_q_minus_p[VEC_SIZE]; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + q_vec.fill(DType(0)); + p_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + if (num_accepted_tokens != num_speculative_tokens - 1) { + // there is no draft_probs for the bonus token + p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0)); + } + sum_relu_q_minus_p += BlockReduce(temp_storage.block_prim.reduce) + .Sum(relu_q_minus_p); + __syncthreads(); + } + if (tx == 0) { + temp_storage.block_aggregate.value = sum_relu_q_minus_p; + } + // init the first rejected token to (d - 1) + temp_storage.sampled_id = d - 1; + __syncthreads(); + sum_relu_q_minus_p = temp_storage.block_aggregate.value; + DType u = coin * sum_relu_q_minus_p; + + DType aggregate_relu_q_minus_p(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + q_vec.fill(DType(0)); + p_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + if (num_accepted_tokens != num_speculative_tokens - 1) { + // there is no draft_probs for the bonus token + p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + } + + vec_t relu_q_minus_p_vec; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0)); + } + + DeviceSamplingFromProb( + i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage); + if (aggregate_relu_q_minus_p > u) { + break; + } + } + __syncthreads(); + // set the first rejected token + predicts[last_accepted_retrive_idx] = temp_storage.sampled_id; + // value at not used indices are undefined +} + +template +cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids, + IdType* output_accepted_token_num, // mutable + IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, + IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs, + DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens, + uint32_t num_draft_tokens, uint32_t d, bool deterministic, + cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&predicts, + &output_token_ids, + &output_accepted_token_num, + &candidates, + &retrive_index, + &retrive_next_token, + &retrive_next_sibling, + &uniform_samples, + &target_probs, + &draft_probs, + &batch_size, + &num_speculative_tokens, + &num_draft_tokens, + &d}; + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TreeSpeculativeSamplingTargetOnly; + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; +} + +} // namespace sampling + +} // namespace flashinfer + +#endif // SPECULATIVE_SAMPLING_CUH_ diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 1fdcc9c35ae..266935ad9fe 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -127,3 +127,19 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, int64_t cuda_stream); + +void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, at::Tensor retrive_index, + at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, + at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, + bool deterministic = true, int64_t cuda_stream = 0); + +void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, + at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, + at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, + int64_t depth, int64_t draft_token_num); + +void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, + at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, + int64_t depth, int64_t draft_token_num); diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 5aa484ff54d..55f83bddd1a 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -495,3 +495,87 @@ def min_p_sampling_from_probs( return _min_p_sampling_from_probs_internal( probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic ) + + +def tree_speculative_sampling_target_only( + predicts: torch.Tensor, # mutable + accept_index: torch.Tensor, # mutable + accept_token_num: torch.Tensor, # mutable + candidates: torch.Tensor, + retrive_index: torch.Tensor, + retrive_next_token: torch.Tensor, + retrive_next_sibling: torch.Tensor, + uniform_samples: torch.Tensor, + target_probs: torch.Tensor, + draft_probs: torch.Tensor, + deterministic: bool = True, +) -> None: + with predicts.device as device: + torch.ops.sgl_kernels.tree_speculative_sampling_target_only( + predicts, + accept_index, + accept_token_num, + candidates, + retrive_index, + retrive_next_token, + retrive_next_sibling, + uniform_samples, + target_probs, + draft_probs, + deterministic, + _get_cuda_stream(device), + ) + + +def build_tree_kernel_efficient( + parent_list: torch.Tensor, + selected_index: torch.Tensor, + verified_seq_len: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_next_token: torch.Tensor, + retrive_next_sibling: torch.Tensor, + topk: int, + depth: int, + draft_token_num: int, +) -> None: + with parent_list.device as device: + torch.ops.sgl_kernels.build_tree_kernel_efficient( + parent_list, + selected_index, + verified_seq_len, + tree_mask, + positions, + retrive_index, + retrive_next_token, + retrive_next_sibling, + topk, + depth, + draft_token_num, + ) + + +def build_tree_kernel( + parent_list: torch.Tensor, + selected_index: torch.Tensor, + verified_seq_len: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + topk: int, + depth: int, + draft_token_num: int, +) -> None: + with parent_list.device as device: + torch.ops.sgl_kernels.build_tree_kernel( + parent_list, + selected_index, + verified_seq_len, + tree_mask, + positions, + retrive_index, + topk, + depth, + draft_token_num, + ) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index aaed142a1ef..e964b32683e 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -130,6 +130,29 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + + // tree spec decode + m.def( + "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " + "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " + "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, " + "bool deterministic, int cuda_stream) -> ()"); + m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only); + + // eagle build tree + m.def( + "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " + "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! " + "retrive_next_sibling, " + "int topk, int depth, int draft_token_num) -> ()"); + m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); + + // eagle build tree + m.def( + "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " + "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, " + "int topk, int depth, int draft_token_num) -> ()"); + m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel); } REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/test_speculative_sampling.py b/sgl-kernel/tests/test_speculative_sampling.py new file mode 100644 index 00000000000..545c3725ac2 --- /dev/null +++ b/sgl-kernel/tests/test_speculative_sampling.py @@ -0,0 +1,104 @@ +import torch +import torch.nn.functional as F +from sgl_kernel import tree_speculative_sampling_target_only + + +def test_tree_speculative_sampling_target_only(): + candidates = torch.tensor( + [ + [0, 1, 2, 3, 4, 5], + [7, 8, 9, 10, 11, 12], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_index = torch.tensor( + [ + [0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_next_token = torch.tensor( + [ + [1, 2, -1, 4, 5, -1], + [4, 2, 3, -1, 5, -1], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_next_sibling = torch.tensor( + [ + [-1, 3, -1, -1, -1, -1], + [-1, -1, -1, -1, 1, -1], + ], + dtype=torch.int32, + device="cuda", + ) + + target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda") + target_logits[0, 0, 3] = 10 + target_logits[0, 3, 4] = 10 + target_logits[0, 4, 5] = 10 + target_logits[1, 0, 11] = 10 + target_logits[1, 4, 12] = 10 + for i in range(target_logits.shape[0]): + for j in range(target_logits.shape[1]): + if torch.max(target_logits[i][j]) < 10: + target_logits[i][j][18] = 10 + + temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda") + predict_shape = (12,) + + bs = candidates.shape[0] + num_spec_step = 4 + num_draft_tokens = candidates.shape[1] + + predicts = torch.full( + predict_shape, -1, dtype=torch.int32, device="cuda" + ) # mutable + accept_index = torch.full( + (bs, num_spec_step), -1, dtype=torch.int32, device="cuda" + ) # mutable + accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable + + expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1) + target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) + draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda") + + coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32) + print(f"{candidates=}") + print(f"{retrive_index=}") + print(f"{retrive_next_token=}") + print(f"{retrive_next_sibling=}") + print(f"{coins=}") + + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + uniform_samples=coins, + target_probs=target_probs, + draft_probs=draft_probs, + deterministic=True, + ) + + print(f"{predicts=}") + print(f"{accept_index=}") + print(f"{accept_token_num=}") + + assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 3, 4, 5], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [3, 2] + + +if __name__ == "__main__": + test_tree_speculative_sampling_target_only() diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 647733203b6..6630dd6366a 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.3.post1" +__version__ = "0.0.3.post2"