From ddc4169f3aeddf2d3ee2d5ba5228ccf5d94b8379 Mon Sep 17 00:00:00 2001
From: b4b4o
Date: Sat, 9 Mar 2024 22:13:58 +0800
Subject: [PATCH] feat: add invokeBatchTopK and invokeMedusaBatchMatch

---
 src/turbomind/kernels/reduce_kernel_utils.cuh |  20 ++
 .../kernels/sampling_topk_kernels.cu          | 176 ++++++++++++++++++
 src/turbomind/kernels/sampling_topk_kernels.h |  12 ++
 src/turbomind/models/llama/llama_kernels.cu   |  66 +++++++
 src/turbomind/models/llama/llama_kernels.h    |   9 +
 .../models/medusa_plugin/medusa_head.cc       |  51 ++++-
 .../models/medusa_plugin/medusa_head.h        |   5 +-
 7 files changed, 333 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/kernels/reduce_kernel_utils.cuh b/src/turbomind/kernels/reduce_kernel_utils.cuh
index 614aa684fa..fd96575fe0 100644
--- a/src/turbomind/kernels/reduce_kernel_utils.cuh
+++ b/src/turbomind/kernels/reduce_kernel_utils.cuh
@@ -336,6 +336,26 @@ struct TopK_2 {
     }
 };
 
+template<>
+struct TopK_2<int> {
+    int p = -1;
+    int u = std::numeric_limits<int>::min();
+
+    __device__ __forceinline__ void insert(int elem, int elem_id)
+    {
+        if (elem > u) {
+            u = elem;
+            p = elem_id;
+        }
+    }
+
+    __device__ __forceinline__ void init()
+    {
+        p = -1;
+        u = std::numeric_limits<int>::min();
+    }
+};
+
 template<typename T>
 __device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
 {
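The TopK_2<int> specialization gives the existing block-wide argmax helpers an integer flavour: insert() keeps the largest value seen so far together with its index. Later in this patch the value is a matched-prefix length and the index is a Medusa path id (see the llama_kernels.cu hunk). A minimal host-side sketch of the same semantics; the names below are illustrative and not part of the patch:

// Hedged CPU analogue of TopK_2<int> + reduce_topk_op_2<int>.
#include <limits>
#include <utility>
#include <vector>

// Returns {best_index, best_value}; {-1, INT_MIN} for empty input, matching init().
static std::pair<int, int> argmax_int(const std::vector<int>& values)
{
    std::pair<int, int> best{-1, std::numeric_limits<int>::min()};
    for (int i = 0; i < static_cast<int>(values.size()); ++i) {
        if (values[i] > best.second) {  // same comparison as TopK_2<int>::insert()
            best = {i, values[i]};
        }
    }
    return best;
}
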
diff --git a/src/turbomind/kernels/sampling_topk_kernels.cu b/src/turbomind/kernels/sampling_topk_kernels.cu
index 82b208298d..2ab7e1502d 100644
--- a/src/turbomind/kernels/sampling_topk_kernels.cu
+++ b/src/turbomind/kernels/sampling_topk_kernels.cu
@@ -619,4 +619,180 @@ template void invokeTopKTopPSampling(void* workspace,
                                      const int*   end_ids,
                                      cudaStream_t stream);
 
+template<typename T, int BLOCK_SIZE_>
+__global__ void topk_only(const T* __restrict log_probs,
+                          T*          tmp_log_probs,
+                          int*        topk_tmp_id_buf,
+                          T*          topk_tmp_val_buf,
+                          const bool* finished,
+                          const int   max_top_k,
+                          const int*  top_ks,
+                          const int   vocab_size,
+                          const int*  end_ids,
+                          const bool* skip_decode)
+{
+    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int batch_id = bid;
+    if (skip_decode != nullptr && skip_decode[batch_id]) {
+        return;
+    }
+    const int k                  = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
+    const int tmp_log_buf_index  = batch_id * vocab_size;
+    const int tmp_topk_buf_index = batch_id * max_top_k;
+
+    TopK_2<T>  partial;
+    const bool IS_FP16   = std::is_same<T, half>::value;
+    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
+
+    for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
+        int index            = elem_id + tmp_log_buf_index;
+        tmp_log_probs[index] = log_probs[index];
+    }
+
+    for (int ite = 0; ite < k; ite++) {
+        partial.init();
+#pragma unroll
+        for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
+            int index = elem_id + tmp_log_buf_index;
+            partial.insert(tmp_log_probs[index], index);
+        }
+        TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
+
+        if (tid == 0) {
+            const int index         = tmp_topk_buf_index + ite;
+            topk_tmp_id_buf[index]  = total.p % vocab_size;
+            topk_tmp_val_buf[index] = total.u;
+            tmp_log_probs[total.p]  = -MAX_T_VAL;
+        }
+        __syncthreads();
+    }
+}
+
+#ifdef _MSC_VER
+#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                          \
+    if (K_MIN <= max_top_k && max_top_k <= K_MAX) {                                          \
+        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,         \
+                                                                          temp_log_probs,    \
+                                                                          topk_tmp_id_buf,   \
+                                                                          topk_tmp_val_buf,  \
+                                                                          finished,          \
+                                                                          max_top_k,         \
+                                                                          top_ks,            \
+                                                                          vocab_size,        \
+                                                                          end_ids,           \
+                                                                          skip_decode);      \
+        break;                                                                               \
+    }
+#else
+#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_)                                          \
+    case K_MIN ... K_MAX:                                                                    \
+        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs,         \
+                                                                          temp_log_probs,    \
+                                                                          topk_tmp_id_buf,   \
+                                                                          topk_tmp_val_buf,  \
+                                                                          finished,          \
+                                                                          max_top_k,         \
+                                                                          top_ks,            \
+                                                                          vocab_size,        \
+                                                                          end_ids,           \
+                                                                          skip_decode);      \
+        break;
+#endif
+
+template<typename T>
+void invokeBatchTopKOnly(void*        workspace,
+                         size_t&      workspace_size,
+                         const T*     log_probs,
+                         bool*        finished,
+                         const int    max_top_k,
+                         const int*   top_ks,
+                         const int    vocab_size_padded,
+                         const int*   end_ids,
+                         cudaStream_t stream,
+                         const int    batch_size,
+                         const bool*  skip_decode)
+{
+
+    const int vocab_size              = vocab_size_padded;
+    int       temp_log_probs_buf_size = batch_size * vocab_size;
+    int       topk_tmp_ids_buf_size   = batch_size * max_top_k;
+    int       topk_tmp_val_buf_size   = batch_size * max_top_k;
+
+    temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
+    topk_tmp_ids_buf_size   = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
+    topk_tmp_val_buf_size   = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
+
+    if (workspace == nullptr) {
+        workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
+                         + sizeof(T) * topk_tmp_val_buf_size;
+        return;
+    }
+
+    T*   temp_log_probs   = (T*)workspace;
+    int* topk_tmp_id_buf  = (int*)(temp_log_probs + temp_log_probs_buf_size);
+    T*   topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
+#ifdef _MSC_VER
+    do {
+        ONLY_TOPK_CASE_K(1, 16, 128 * 8);
+        ONLY_TOPK_CASE_K(17, 32, 256 * 8);
+        ONLY_TOPK_CASE_K(33, 64, 256 * 8);
+        ONLY_TOPK_CASE_K(65, 1024, 256 * 8);
+        throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
+    } while (0);
+#else
+    switch (max_top_k) {
+        ONLY_TOPK_CASE_K(1, 16, 128 * 8);
+        ONLY_TOPK_CASE_K(17, 32, 256 * 8);
+        ONLY_TOPK_CASE_K(33, 64, 256 * 8);
+        ONLY_TOPK_CASE_K(65, 1024, 256 * 8);
+        default:
+            throw std::domain_error(fmtstr("only top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
+    }
+#endif
+}
+
+#undef ONLY_TOPK_CASE_K
+
+template void invokeBatchTopKOnly(void*        workspace,
+                                  size_t&      workspace_size,
+                                  const half*  log_probs,
+                                  bool*        finished,
+                                  const int    max_top_k,
+                                  const int*   top_ks,
+                                  const int    vocab_size_padded,
+                                  const int*   end_ids,
+                                  cudaStream_t stream,
+                                  const int    batch_size,
+                                  const bool*  skip_decode);
+
+template void invokeBatchTopKOnly(void*        workspace,
+                                  size_t&      workspace_size,
+                                  const float* log_probs,
+                                  bool*        finished,
+                                  const int    max_top_k,
+                                  const int*   top_ks,
+                                  const int    vocab_size_padded,
+                                  const int*   end_ids,
+                                  cudaStream_t stream,
+                                  const int    batch_size,
+                                  const bool*  skip_decode);
+
+#ifdef ENABLE_BF16
+template void invokeBatchTopKOnly(void*                workspace,
+                                  size_t&              workspace_size,
+                                  const __nv_bfloat16* log_probs,
+                                  bool*                finished,
+                                  const int            max_top_k,
+                                  const int*           top_ks,
+                                  const int            vocab_size_padded,
+                                  const int*           end_ids,
+                                  cudaStream_t         stream,
+                                  const int            batch_size,
+                                  const bool*          skip_decode);
+#endif
 }  // namespace turbomind
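For reference, a host-side sketch of the calling convention invokeBatchTopKOnly expects: a first call with workspace == nullptr only reports the byte count, the second call runs the kernel, and the selected ids live inside the workspace right after the padded temporary logits. The explicit <float> instantiation and the variable names below are illustrative only, not part of the patch:

// Hedged usage sketch for invokeBatchTopKOnly (assumes d_logits is already on the GPU).
#include <cuda_runtime.h>
#include <vector>
#include "src/turbomind/kernels/sampling_topk_kernels.h"

void run_batch_topk_example(const float* d_logits, int batch_size, int vocab_size, int k, cudaStream_t stream)
{
    using namespace turbomind;
    size_t ws_bytes = 0;
    // 1) Size query: with a null workspace the function only fills ws_bytes and returns.
    invokeBatchTopKOnly<float>(nullptr, ws_bytes, d_logits, nullptr, k, nullptr,
                               vocab_size, nullptr, stream, batch_size, nullptr);
    void* ws = nullptr;
    cudaMalloc(&ws, ws_bytes);
    // 2) Real run: ids and values are written into the workspace itself.
    invokeBatchTopKOnly<float>(ws, ws_bytes, d_logits, nullptr, k, nullptr,
                               vocab_size, nullptr, stream, batch_size, nullptr);
    // Workspace layout (see the implementation above): temp logits padded to a multiple
    // of 4 elements, then batch_size * k ids (padded), then the matching values.
    size_t logits_elems = ((size_t)batch_size * vocab_size + 3) / 4 * 4;
    const int* d_topk_ids = reinterpret_cast<const int*>(reinterpret_cast<const float*>(ws) + logits_elems);
    // 3) Copy the selected ids back to the host.
    std::vector<int> h_ids((size_t)batch_size * k);
    cudaMemcpyAsync(h_ids.data(), d_topk_ids, h_ids.size() * sizeof(int), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    cudaFree(ws);
}
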
diff --git a/src/turbomind/kernels/sampling_topk_kernels.h b/src/turbomind/kernels/sampling_topk_kernels.h
index a539abf0fa..8530e8f5eb 100644
--- a/src/turbomind/kernels/sampling_topk_kernels.h
+++ b/src/turbomind/kernels/sampling_topk_kernels.h
@@ -95,4 +95,16 @@ void invokeTopKTopPSampling(void* workspace,
                             const int*   end_ids,
                             cudaStream_t stream);
 
+template<typename T>
+void invokeBatchTopKOnly(void*        workspace,
+                         size_t&      workspace_size,
+                         const T*     log_probs,
+                         bool*        finished,
+                         const int    max_top_k,
+                         const int*   top_ks,
+                         const int    vocab_size_padded,
+                         const int*   end_ids,
+                         cudaStream_t stream,
+                         const int    batch_size,
+                         const bool*  skip_decode);
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index 6444a52602..31f450f05f 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -905,4 +905,70 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t stream)
     });
 }
 
+template<int BLOCK_SIZE_>
+__global__ void medusaBatchedMatchKernel(const int* __restrict__ input_ids,
+                                         const int* __restrict__ output_ids,
+                                         int*       match_idx,
+                                         int*       match_length,
+                                         const int  path_num,
+                                         int        size)
+{
+    //[b, path_num, 1 + head_num]
+    const int length  = size + 1;
+    const int limit_r = gridDim.x * path_num * length;
+    const int bid     = blockIdx.x;   // (0, batch_size)
+    const int tid     = threadIdx.x;  // (0, BLOCK_SIZE_)
+
+    typedef cub::BlockReduce<TopK_2<int>, BLOCK_SIZE_> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    TopK_2<int> partial;
+    partial.init();
+
+    for (int idx = tid; idx < path_num; idx += BLOCK_SIZE_) {
+        int start_id          = bid * path_num * length + idx * length;  // belong to (bid, path_id)
+        int accumulate_length = 0;
+        for (int i = 0; i < size && (start_id + i) < limit_r; ++i) {
+            if (input_ids[start_id + i + 1] == output_ids[start_id + i]) {
+                ++accumulate_length;
+            }
+            else {
+                break;
+            }
+        }
+        partial.insert(accumulate_length, idx);
+    }
+
+    TopK_2<int> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<int>);
+
+    if (tid == 0) {
+        const int index     = bid;
+        match_idx[index]    = total.p;
+        match_length[index] = total.u;
+    }
+    __syncthreads();
+}
+
+void invokeMedusaBatchedMatchKernel(const int*   input_ids,
+                                    const int*   output_ids,
+                                    int*         max_match_idx,
+                                    int*         max_match_length,
+                                    int          batch_size,
+                                    int          path_num,
+                                    int          medusa_head_num,
+                                    cudaStream_t stream)
+{
+    // inputs:
+    //      input_ids: [batch_size, path_num, 1 + medusa_head_num]
+    //      output_ids: [batch_size, path_num, 1 + medusa_head_num]
+    // outputs:
+    //      max_match_idx: [batch_size]
+    //      max_match_length: [batch_size]
+    dim3 grid, block;
+    grid.x  = batch_size;
+    block.x = 64;
+    medusaBatchedMatchKernel<64><<<grid, block, 0, stream>>>(
+        input_ids, output_ids, max_match_idx, max_match_length, path_num, medusa_head_num);
+}
+
 }  // namespace turbomind
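The matching rule the kernel implements is easiest to read as a CPU reference: for every (batch, path) pair, count how many proposed tokens input_ids[.., i + 1] agree with the produced tokens output_ids[.., i] until the first mismatch, then keep the path with the longest run. A hedged sketch with illustrative names that are not part of the patch:

// Hedged CPU reference of medusaBatchedMatchKernel's matching rule.
#include <limits>
#include <vector>

static void medusa_batched_match_ref(const std::vector<int>& input_ids,
                                     const std::vector<int>& output_ids,
                                     std::vector<int>&       match_idx,
                                     std::vector<int>&       match_length,
                                     int batch, int path_num, int head_num)
{
    const int len = 1 + head_num;  // layout: [batch, path_num, 1 + head_num]
    match_idx.assign(batch, -1);
    match_length.assign(batch, std::numeric_limits<int>::min());  // mirrors TopK_2<int>::init()
    for (int b = 0; b < batch; ++b) {
        for (int p = 0; p < path_num; ++p) {
            const int base = (b * path_num + p) * len;
            int       acc  = 0;
            // token i proposed along this path (input_ids[base + i + 1]) must equal
            // the token actually produced at step i (output_ids[base + i]).
            for (int i = 0; i < head_num; ++i) {
                if (input_ids[base + i + 1] == output_ids[base + i]) {
                    ++acc;
                }
                else {
                    break;
                }
            }
            if (acc > match_length[b]) {  // keep the first path with the strictly longest run
                match_length[b] = acc;
                match_idx[b]    = p;
            }
        }
    }
}
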
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index 3b01dee60d..cd10939bcc 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -167,4 +167,13 @@ inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t stream)
     TM_LOG_ERROR("--------> rank = %d, step = %d, seq_len = %d <--------", tp_rank, step, h_seq_len);
 }
 
+void invokeMedusaBatchedMatchKernel(const int*   input_ids,
+                                    const int*   output_ids,
+                                    int*         max_match_idx,
+                                    int*         max_match_length,
+                                    int          batch_size,
+                                    int          path_num,
+                                    int          medusa_head_num,
+                                    cudaStream_t stream);
+
 }  // namespace turbomind
diff --git a/src/turbomind/models/medusa_plugin/medusa_head.cc b/src/turbomind/models/medusa_plugin/medusa_head.cc
index d5c9652e53..9febb22645 100644
--- a/src/turbomind/models/medusa_plugin/medusa_head.cc
+++ b/src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -3,6 +3,7 @@
 // Zhiwei Bao
 
 #include "src/turbomind/models/medusa_plugin/medusa_head.h"
+#include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
@@ -36,14 +37,18 @@ void MedusaHead<T>::forward(TensorMap* output_tensors,
                             const TensorMap*       input_tensors,
                             const MedusaWeight<T>& medusa_weight)
 {
-    const size_t batch_size             = input_tensors->at("medusa_head_input").shape[0];
-    const T*     hidden_states          = input_tensors->at("medusa_head_input").getPtr<T>();
-    T*           medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr<T>();
+    const size_t batch_size        = input_tensors->at("medusa_head_input").shape[0];
+    const T*     hidden_states     = input_tensors->at("medusa_head_input").getPtr<T>();
+    int*         h_topk_output_ids = output_tensors->at("medusa_head_output").getPtr<int>();
+
+    allocate_buffer(batch_size);
     // TODO parallelize this loop
     for (int i = 0; i < medusa_num_heads_; i++) {
-        T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_;
+        T* medusa_head_logits = medusa_head_logits_buf_ + i * batch_size * vocab_size_;
         forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
     }
+
+    top_k(h_topk_output_ids, medusa_head_logits_buf_, batch_size * medusa_num_heads_);
 }
 
 template<typename T>
@@ -53,7 +58,6 @@ void MedusaHead<T>::forward(T* medusa_head_output,
                             const MedusaWeight<T>& medusa_weight,
                             int                    head_id)
 {
-    allocate_buffer(batch_size);
     // TODO support multi medusa_num_layers
     resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
     linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);
@@ -72,6 +76,8 @@ void MedusaHead<T>::allocate_buffer(size_t batch_size)
 {
     resblock_buf_ =
         (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
+    medusa_head_logits_buf_ = (T*)allocator_->reMalloc(
+        medusa_head_logits_buf_, medusa_num_heads_ * sizeof(T) * batch_size * vocab_size_, false);
     is_allocated_buffer_ = true;
 }
 
@@ -80,10 +86,45 @@ void MedusaHead<T>::free_buffer()
 {
     if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
         allocator_->free((void**)&resblock_buf_);
+        allocator_->free((void**)&workspace_buf_);
+        allocator_->free((void**)&medusa_head_logits_buf_);
         is_allocated_buffer_ = false;
     }
 }
 
+template<typename T>
+void MedusaHead<T>::top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k)
+{
+    size_t workspace_size_now = 0;
+    invokeBatchTopKOnly(nullptr,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    workspace_buf_ = (void*)allocator_->reMalloc(workspace_buf_, workspace_size_now, false);
+    invokeBatchTopKOnly(workspace_buf_,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    int  offset          = (int)(ceil(batch_size * vocab_size_ / 4.)) * 4;
+    int  output_size     = (int)(ceil(batch_size * k / 4.)) * 4;
+    int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset);
+    cudaMemcpy(h_topk_output_ids, topk_output_ids, sizeof(int) * output_size, cudaMemcpyDeviceToHost);
+}
+
 template class MedusaHead<float>;
 template class MedusaHead<half>;
 #ifdef ENABLE_BF16
diff --git a/src/turbomind/models/medusa_plugin/medusa_head.h b/src/turbomind/models/medusa_plugin/medusa_head.h
index 44d51035ed..62c7f618e4 100644
--- a/src/turbomind/models/medusa_plugin/medusa_head.h
+++ b/src/turbomind/models/medusa_plugin/medusa_head.h
@@ -38,6 +38,7 @@ class MedusaHead {
 private:
     void allocate_buffer(size_t batch_size);
     void free_buffer();
+    void top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k = 1);
 
 private:
     size_t in_size_;
@@ -47,7 +48,9 @@ class MedusaHead {
     std::unique_ptr<ResBlock<T>>    resblock_;
     std::unique_ptr<LlamaLinear<T>> linear_;
 
-    T* resblock_buf_;
+    T*    resblock_buf_;
+    void* workspace_buf_;
+    T*    medusa_head_logits_buf_;
 
     cudaStream_t     stream_;
     cublasMMWrapper* cublas_wrapper_;
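With this change the "medusa_head_output" tensor no longer receives raw logits: MedusaHead<T>::forward writes per-head logits into medusa_head_logits_buf_ and top_k() copies the selected token ids back to the host pointer taken from the output tensor. Since top_k() is called with an effective batch of batch_size * medusa_num_heads_ and the default k = 1, the ids come out head-major. A hedged sketch of how a caller might unpack them; the function and variable names are illustrative and not part of the patch:

// Hedged sketch: unpack the head-major ids written by MedusaHead<T>::top_k with the default k = 1.
#include <vector>

std::vector<std::vector<int>> unpack_medusa_draft_tokens(const int* h_topk_output_ids, int batch_size, int medusa_num_heads)
{
    // drafts[b][h] = greedy token proposed by medusa head h for batch element b,
    // because medusa_head_logits_buf_ is filled head by head before top_k() runs.
    std::vector<std::vector<int>> drafts(batch_size, std::vector<int>(medusa_num_heads));
    for (int h = 0; h < medusa_num_heads; ++h) {
        for (int b = 0; b < batch_size; ++b) {
            drafts[b][h] = h_topk_output_ids[h * batch_size + b];
        }
    }
    return drafts;
}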