From 2a27648efc201a93f5825f7e8b1f2ee7b7d1db28 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 10 Mar 2024 22:00:19 +0800 Subject: [PATCH] sync --- .../turbomind/deploy/target_model/base.py | 2 + .../kernels/sampling_topk_kernels.cu | 121 +++----- src/turbomind/kernels/sampling_topk_kernels.h | 18 +- src/turbomind/models/llama/CMakeLists.txt | 3 +- src/turbomind/models/llama/LlamaBatch.cc | 250 +++++++++++++++- src/turbomind/models/llama/LlamaBatch.h | 33 +- src/turbomind/models/llama/LlamaV2.cc | 60 +++- src/turbomind/models/llama/LlamaV2.h | 14 +- src/turbomind/models/llama/LlamaWeight.cc | 29 +- src/turbomind/models/llama/LlamaWeight.h | 10 +- src/turbomind/models/llama/SequenceManager.h | 2 + src/turbomind/models/llama/llama_kernels.cu | 22 +- src/turbomind/models/llama/llama_kernels.h | 15 +- .../models/medusa_plugin/CMakeLists.txt | 6 +- .../models/medusa_plugin/medusa_head.cc | 26 +- .../models/medusa_plugin/medusa_head.h | 6 +- .../models/medusa_plugin/medusa_weight.cc | 4 +- .../models/medusa_plugin/medusa_weight.h | 8 +- .../tests/medusa_head_example.cc | 281 ++++++++++++++++++ .../triton_backend/llama/LlamaTritonModel.cc | 13 +- .../triton_backend/llama/LlamaTritonModel.h | 3 + 21 files changed, 759 insertions(+), 167 deletions(-) create mode 100644 src/turbomind/models/medusa_plugin/tests/medusa_head_example.cc diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 5c6cc415d5..9113658eb8 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -67,6 +67,8 @@ class TurbomindModelConfig: rope_scaling_factor: float = 0.0 use_dynamic_ntk: int = 0 use_logn_attn: int = 0 + medusa_num_heads: int = 0 + medusa_num_layers: int = 0 @classmethod def from_dict(cls, env, allow_none=False): diff --git a/src/turbomind/kernels/sampling_topk_kernels.cu b/src/turbomind/kernels/sampling_topk_kernels.cu index 2ab7e1502d..e87f1c6a40 100644 --- a/src/turbomind/kernels/sampling_topk_kernels.cu +++ b/src/turbomind/kernels/sampling_topk_kernels.cu @@ -620,16 +620,12 @@ template void invokeTopKTopPSampling(void* workspace, cudaStream_t stream); template -__global__ void topk_only(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size, - const int* end_ids, - const bool* skip_decode) +__global__ void topk(const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int max_top_k, + const int vocab_size) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -637,11 +633,8 @@ __global__ void topk_only(const T* __restrict log_probs, const int tid = threadIdx.x; const int bid = blockIdx.x; - const int batch_id = bid; - if (skip_decode != nullptr && skip_decode[batch_id]) { - return; - } - const int k = (top_ks != nullptr) ? 
top_ks[batch_id] : max_top_k; + const int batch_id = bid; + const int k = max_top_k; const int tmp_log_buf_index = batch_id * vocab_size; const int tmp_topk_buf_index = batch_id * max_top_k; @@ -676,46 +669,26 @@ __global__ void topk_only(const T* __restrict log_probs, #ifdef _MSC_VER #define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \ if (K_MIN <= max_top_k && max_top_k <= K_MAX) { \ - topk_only<<>>(log_probs, \ - temp_log_probs, \ - topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - finished, \ - max_top_k, \ - top_ks, \ - vocab_size, \ - end_ids, \ - skip_decode); \ + topk<<>>( \ + log_probs, temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, max_top_k, vocab_size); \ break; \ } #else #define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \ case K_MIN ... K_MAX: \ - topk_only<<>>(log_probs, \ - temp_log_probs, \ - topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - finished, \ - max_top_k, \ - top_ks, \ - vocab_size, \ - end_ids, \ - skip_decode); \ + topk<<>>( \ + log_probs, temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, max_top_k, vocab_size); \ break; #endif template -void invokeBatchTopKOnly(void* workspace, - size_t& workspace_size, - const T* log_probs, - bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size, - const bool* skip_decode) +void invokeBatchTopK(void* workspace, + size_t& workspace_size, + const T* log_probs, + const int max_top_k, + const int vocab_size_padded, + cudaStream_t stream, + const int batch_size) { const int vocab_size = vocab_size_padded; @@ -758,41 +731,29 @@ void invokeBatchTopKOnly(void* workspace, #undef ONLY_TOPK_CASE_K -template void invokeBatchTopKOnly(void* workspace, - size_t& workspace_size, - const half* log_probs, - bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size, - const bool* skip_decode); - -template void invokeBatchTopKOnly(void* workspace, - size_t& workspace_size, - const float* log_probs, - bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size, - const bool* skip_decode); +template void invokeBatchTopK(void* workspace, + size_t& workspace_size, + const half* log_probs, + const int max_top_k, + const int vocab_size_padded, + cudaStream_t stream, + const int batch_size); + +template void invokeBatchTopK(void* workspace, + size_t& workspace_size, + const float* log_probs, + const int max_top_k, + const int vocab_size_padded, + cudaStream_t stream, + const int batch_size); #ifdef ENABLE_BF16 -template void invokeBatchTopKOnly(void* workspace, - size_t& workspace_size, - const __nv_bfloat16* log_probs, - bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size, - const bool* skip_decode); +template void invokeBatchTopK(void* workspace, + size_t& workspace_size, + const __nv_bfloat16* log_probs, + const int max_top_k, + const int vocab_size_padded, + cudaStream_t stream, + const int batch_size); #endif } // namespace turbomind diff --git a/src/turbomind/kernels/sampling_topk_kernels.h b/src/turbomind/kernels/sampling_topk_kernels.h index 8530e8f5eb..f1129bf924 100644 --- a/src/turbomind/kernels/sampling_topk_kernels.h +++ b/src/turbomind/kernels/sampling_topk_kernels.h @@ -96,15 +96,11 @@ void invokeTopKTopPSampling(void* workspace, 
cudaStream_t stream); template -void invokeBatchTopKOnly(void* workspace, - size_t& workspace_size, - const T* log_probs, - bool* finished, - const int max_top_k, - const int* top_ks, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream, - const int batch_size, - const bool* skip_decode); +void invokeBatchTopK(void* workspace, + size_t& workspace_size, + const T* log_probs, + const int max_top_k, + const int vocab_size_padded, + cudaStream_t stream, + const int batch_size); } // namespace turbomind diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 1e5889839c..8d5970ba65 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -37,7 +37,8 @@ target_link_libraries(Llama PUBLIC CUDA::cudart nccl_utils cuda_utils logger - llama_fmha) + llama_fmha + Medusa) add_executable(llama_gemm llama_gemm.cc) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 8a49005100..875b76de72 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -245,6 +245,9 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) FT_CHECK(state.sequences[idx]); auto& seq = *state.sequences[idx]; + if (medusa_enable_) { + seq.iter = 0; + } if (int step = r->inputs[rank_].getVal("step", -1); step >= 0) { if (step <= seq.tokens.size()) { @@ -734,6 +737,29 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len) rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false); + medusa_inited_seq_hidden_states_buf_ = + (T*)allocator_->reMalloc(medusa_inited_seq_hidden_states_buf_, + sizeof(T) * max_batch_size_ * (1 + medusa_num_heads_) * hidden_units, + false); + medusa_new_seq_last_hidden_state_buf_ = (T*)allocator_->reMalloc( + medusa_new_seq_last_hidden_state_buf_, sizeof(T) * max_batch_size_ * hidden_units, false); + medusa_inited_seq_verified_last_hidden_state_buf_ = (T*)allocator_->reMalloc( + medusa_inited_seq_verified_last_hidden_state_buf_, sizeof(T) * max_batch_size_ * hidden_units, false); + medusa_inited_input_ids_buf_ = (int*)allocator_->reMalloc( + medusa_inited_input_ids_buf_, sizeof(int) * max_batch_size_ * (1 + medusa_num_heads_), false); + + medusa_output_ids_buf_ = (int*)allocator_->reMalloc( + medusa_output_ids_buf_, sizeof(int) * max_batch_size_ * (1 + medusa_num_heads_), false); + medusa_verified_last_output_ids_buf_ = (int*)allocator_->reMalloc( + medusa_verified_last_output_ids_buf_, sizeof(int) * max_batch_size_ * (1 + medusa_num_heads_), false); + medusa_end_ids_buf_ = + (int*)allocator_->reMalloc(medusa_end_ids_buf_, sizeof(int) * max_batch_size_ * (1 + medusa_num_heads_), false); + deviceFill(medusa_end_ids_buf_, max_batch_size_ * (1 + medusa_num_heads_), model_->end_id_, stream_); + + max_match_length_buf_ = (int*)allocator_->reMalloc(max_match_length_buf_, sizeof(int) * max_batch_size_, false); + h_max_match_length_buf_ = + (int*)allocator_->reMalloc(h_max_match_length_buf_, sizeof(int) * max_batch_size_, false, true); + is_allocate_buffer_ = true; } @@ -860,6 +886,24 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&rope_theta_); + allocator_->free((void**)&medusa_inited_seq_hidden_states_buf_); + allocator_->free((void**)&medusa_new_seq_last_hidden_state_buf_); + allocator_->free((void**)&medusa_inited_seq_verified_last_hidden_state_buf_); + allocator_->free((void**)&medusa_inited_input_ids_buf_); + + if 
(medusa_logits_buf_) { + allocator_->free((void**)&medusa_logits_buf_); + } + if (medusa_local_logits_buf_) { + allocator_->free((void**)&medusa_local_logits_buf_); + } + allocator_->free((void**)&medusa_output_ids_buf_); + allocator_->free((void**)&medusa_verified_last_output_ids_buf_); + allocator_->free((void**)&medusa_end_ids_buf_); + + allocator_->free((void**)&max_match_length_buf_); + allocator_->free((void**)&h_max_match_length_buf_); + is_allocate_buffer_ = false; } @@ -894,7 +938,8 @@ void LlamaBatch::FreeBuffer() } template -LlamaBatch::LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2* model): +LlamaBatch::LlamaBatch( + const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2* model, int medusa_num_heads): max_batch_size_(params.max_batch_size), max_context_token_num_(params.max_context_token_num), session_len_(params.session_len), @@ -905,7 +950,9 @@ LlamaBatch::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i data_type_(getTensorType()), num_tokens_per_iter_(params.num_tokens_per_iter), extra_tokens_per_iter_(params.extra_tokens_per_iter), - max_prefill_iters_(params.max_prefill_iters) + max_prefill_iters_(params.max_prefill_iters), + medusa_num_heads_(medusa_num_heads), + medusa_enable_(medusa_num_heads != 0) { stream_ = model_->stream_; allocator_ = model_->allocator_; @@ -1449,6 +1496,10 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) int pf_offset = -1; std::vector input_d_ptrs(active_size); + medusa_state_vec_.resize(active_size); + int inited_index = 0; + int new_index = 0; + if (iter == 0) { // The first iter may have pre-fill tokens for (int i = 0; i < active_size; ++i) { const auto& seq = *state_->sequences[i]; @@ -1459,6 +1510,20 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) if (seq.input_length > 1 && pf_offset < 0) { pf_offset = i; } + if (medusa_enable_) { + auto& medusa_state = medusa_state_vec_[i]; + medusa_state.len = seq.input_length; + if (seq.iter == 0) { + medusa_state.verified_len = seq.input_length; + medusa_state.inited = false; + medusa_state.index = new_index++; + } + else { + medusa_state.verified_len = 0; + medusa_state.inited = true; + medusa_state.index = inited_index++; + } + } } if (pf_offset < 0) { pf_offset = active_size; @@ -1472,6 +1537,10 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) pf_offset = active_size; } + if (medusa_enable_) { + pf_offset = 0; + } + // These buffers are only accessed when there are prefill workloads if (pf_offset != active_size) { Copy(state_->h_context_length, active_size, context_length_buf_); @@ -1525,8 +1594,12 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) batched_copy.Submit(stream_); - const int dc_batch_size = p ? 0 : pf_offset; - const int pf_batch_size = mini_batch_size - dc_batch_size; + int dc_batch_size = p ? 
0 : pf_offset; + int pf_batch_size = mini_batch_size - dc_batch_size; + if (medusa_enable_) { + dc_batch_size = 0; + pf_batch_size = mini_batch_size; + } if (rank_ == 0) { if (pf_batch_size) { @@ -1557,20 +1630,174 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) dc_batch_size, pf_batch_size, sequences.data()); - if (iter == 0) { - // compute logits of inputs if requested - OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences); + if (medusa_enable_) { + // FIXME enable OutputContextLogits + + T* context_decoder_output_src = context_decoder_output_buf_; + + int* context_decoder_ids_src = context_decoder_ids_buf_; + + for (int i = 0; i < mini_batch_size; i++) { + int global_index = i + first; + int len = medusa_state_vec_[global_index].len; + int hidden_units = model_->hidden_units_; + bool inited = medusa_state_vec_[global_index].inited; + int index = medusa_state_vec_[global_index].index; + + if (inited) { + T* context_decoder_output_dst = + medusa_inited_seq_hidden_states_buf_ + index * len * hidden_units; + Copy(context_decoder_output_src, len * hidden_units, context_decoder_output_dst); + + int* context_decoder_ids_dst = medusa_inited_input_ids_buf_ + index * len; + Copy(context_decoder_ids_src, len, context_decoder_ids_dst); + } + else { + T* context_decoder_output_dst = medusa_new_seq_last_hidden_state_buf_ + index * hidden_units; + Copy(context_decoder_output_src + (len - 1) * hidden_units, + hidden_units, + context_decoder_output_dst); + } + + context_decoder_output_src += len * hidden_units; + context_decoder_ids_src += len; + } + } + else { + // compute logits of inputs if requested + OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences); + } + } + } + + if (medusa_enable_) { + // verification + if (inited_index != 0) { + if (medusa_logits_buf_ == nullptr) { + NcclGuard guard(model_->tensor_para_, stream_, true); + medusa_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ + * max_batch_size_ * (1 + medusa_num_heads_)); + const auto tp = model_->tensor_para_.world_size_; + if (tp > 1) { + FT_CHECK(model_->vocab_size_padded_ % tp == 0); + // FIXME + const auto local_vocab_size = model_->vocab_size_padded_ / tp; + medusa_local_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ + * max_batch_size_ * (1 + medusa_num_heads_)); + } + } + + int inited_token_num = inited_index * (1 + medusa_num_heads_); + + model_->postDecodeEmbedding( + medusa_logits_buf_, medusa_local_logits_buf_, medusa_inited_seq_hidden_states_buf_, inited_token_num); + + model_->batchDynamicDecode(medusa_output_ids_buf_, + medusa_logits_buf_, + medusa_end_ids_buf_, + state_->curand_state, + inited_token_num); + + invokeMedusaBatchMatch(medusa_inited_input_ids_buf_, + medusa_output_ids_buf_, + max_match_length_buf_, + inited_token_num, + 1, + medusa_num_heads_, + stream_); + int* h_max_match_length_dst = h_max_match_length_buf_; + Copy(max_match_length_buf_, inited_index, h_max_match_length_dst); + } + + // generation + if (inited_index != 0) { + T* medusa_inited_seq_hidden_states_src = medusa_inited_seq_hidden_states_buf_; + int* medusa_output_ids_src = medusa_output_ids_buf_; + T* medusa_inited_seq_verified_last_hidden_state_dst = medusa_inited_seq_verified_last_hidden_state_buf_; + int* medusa_verified_last_output_ids_dst = medusa_verified_last_output_ids_buf_; + for (int i = 0; i < inited_index; i++) { + medusa_inited_seq_verified_last_hidden_state_dst = + 
Copy(medusa_inited_seq_hidden_states_src + h_max_match_length_buf_[i], + 1, + medusa_inited_seq_verified_last_hidden_state_dst); + + // inited seq + medusa_verified_last_output_ids_dst = Copy( + medusa_output_ids_src + h_max_match_length_buf_[i] - 1, 1, medusa_verified_last_output_ids_dst); + + medusa_inited_seq_hidden_states_src += (1 + medusa_num_heads_) * model_->hidden_units_; + medusa_output_ids_src += 1 + medusa_num_heads_; + } + } + + // new seq + if (new_index != 0) { + if (medusa_logits_buf_ == nullptr) { + NcclGuard guard(model_->tensor_para_, stream_, true); + medusa_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ + * max_batch_size_ * (1 + medusa_num_heads_)); + const auto tp = model_->tensor_para_.world_size_; + if (tp > 1) { + FT_CHECK(model_->vocab_size_padded_ % tp == 0); + // FIXME + const auto local_vocab_size = model_->vocab_size_padded_ / tp; + medusa_local_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ + * max_batch_size_ * (1 + medusa_num_heads_)); + } + } + model_->postDecodeEmbedding( + medusa_logits_buf_, medusa_local_logits_buf_, medusa_new_seq_last_hidden_state_buf_, new_index); + model_->batchDynamicDecode( + medusa_output_ids_buf_, medusa_logits_buf_, medusa_end_ids_buf_, state_->curand_state, new_index); + std::vector new_seq_lm_head_output_ids(new_index, 0.0); + Copy(medusa_output_ids_buf_, new_index, new_seq_lm_head_output_ids.data()); + } + + if (inited_index != 0) { + std::vector inited_seq_lm_head_output_ids(inited_index, 0.0); + Copy(medusa_verified_last_output_ids_buf_, inited_index, inited_seq_lm_head_output_ids.data()); + } + + if (new_index != 0) { + std::vector new_seq_topk_output_ids(medusa_num_heads_ * new_index, 0.0); + int* d_new_seq_topk_output_ids = nullptr; + d_new_seq_topk_output_ids = (int*)allocator_->reMalloc( + d_new_seq_topk_output_ids, sizeof(int) * medusa_num_heads_ * new_index, false); + model_->medusaForward(d_new_seq_topk_output_ids, medusa_new_seq_last_hidden_state_buf_, new_index); + Copy(d_new_seq_topk_output_ids, medusa_num_heads_ * new_index, new_seq_topk_output_ids.data()); + allocator_->free((void**)&d_new_seq_topk_output_ids); + } + + if (inited_index != 0) { + std::vector inited_seq_topk_output_ids(medusa_num_heads_ * inited_index, 0.0); + int* d_inited_seq_topk_output_ids = nullptr; + d_inited_seq_topk_output_ids = (int*)allocator_->reMalloc( + d_inited_seq_topk_output_ids, sizeof(int) * medusa_num_heads_ * inited_index, false); + model_->medusaForward( + d_inited_seq_topk_output_ids, medusa_inited_seq_verified_last_hidden_state_buf_, inited_index); + Copy(d_inited_seq_topk_output_ids, medusa_num_heads_ * inited_index, inited_seq_topk_output_ids.data()); + allocator_->free((void**)&d_inited_seq_topk_output_ids); } } std::fill(h_input_length_buf_, h_input_length_buf_ + active_size, 0); // `SequenceManager` needs real-time value of cache length + int j = 0; for (int i = 0; i < active_size; ++i) { if (state_->requests[i]) { FT_CHECK(state_->sequences[i]); - state_->sequences[i]->cache_len += state_->sequences[i]->input_length; + if (medusa_enable_) { + state_->sequences[i]->iter += 1; + if (medusa_state_vec_[i].inited) { + medusa_state_vec_[i].verified_len = h_max_match_length_buf_[j++]; + } + state_->sequences[i]->cache_len += medusa_state_vec_[i].verified_len; + } + else { + state_->sequences[i]->cache_len += state_->sequences[i]->input_length; + } } } @@ -1625,6 +1852,13 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) return !should_stop; } 
+std::ostream& operator<<(std::ostream& os, const MedusaState& medusa_state) +{ + os << "index=" << medusa_state.index << " len=" << medusa_state.len << " verified_len=" << medusa_state.verified_len + << " inited=" << medusa_state.inited; + return os; +} + template class LlamaBatch; template class LlamaBatch; #ifdef ENABLE_BF16 diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 604cf8b0a3..2b915216b0 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -15,6 +15,7 @@ #include "src/turbomind/utils/cuda_utils.h" #include #include +#include #include namespace turbomind { @@ -40,6 +41,15 @@ struct BatchState { int size; }; +struct MedusaState { + int index; + int len; + int verified_len; + bool inited; + + friend std::ostream& operator<<(std::ostream& os, const MedusaState& medusa_state); +}; + template class LlamaV2; @@ -95,7 +105,8 @@ class LlamaBatch { const std::vector& lengths, const std::vector& sequences); - explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2* model); + explicit LlamaBatch( + const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2* model, int medusa_num_heads); ~LlamaBatch() { @@ -292,6 +303,26 @@ class LlamaBatch { const int num_tokens_per_iter_; const int extra_tokens_per_iter_; const int max_prefill_iters_; + + int medusa_num_heads_ = 0; + bool medusa_enable_ = false; + std::vector medusa_state_vec_; + + T* medusa_inited_seq_hidden_states_buf_{}; + T* medusa_new_seq_last_hidden_state_buf_{}; + T* medusa_inited_seq_verified_last_hidden_state_buf_{}; + + int* medusa_inited_input_ids_buf_{}; + + float* medusa_logits_buf_{}; + float* medusa_local_logits_buf_{}; + + int* medusa_output_ids_buf_{}; + int* medusa_end_ids_buf_{}; + int* medusa_verified_last_output_ids_buf_{}; + + int* max_match_length_buf_{}; + int* h_max_match_length_buf_{}; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index e8bc5fd477..4545998965 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -32,12 +32,15 @@ #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/unified_decoder.h" +#include "src/turbomind/models/medusa_plugin/medusa_head.h" #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/logger.h" #include #include #include +#include +#include namespace turbomind { @@ -63,7 +66,9 @@ LlamaV2::LlamaV2(size_t head_num, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop): + cudaDeviceProp* cuda_device_prop, + int medusa_num_heads, + int medusa_num_layers): head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size), @@ -85,8 +90,9 @@ LlamaV2::LlamaV2(size_t head_num, is_free_buffer_after_forward_(is_free_buffer_after_forward), cuda_device_prop_(cuda_device_prop), debug_(isDebug()), - shared_state_(shared_state) - + shared_state_(shared_state), + medusa_num_heads_(medusa_num_heads), + medusa_num_layers_(medusa_num_layers) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); @@ -94,7 +100,7 @@ LlamaV2::LlamaV2(size_t head_num, vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * 
tensor_para_.world_size_; - batch_ = std::make_unique>(engine_params, cache_block_seq_len, quant_policy, this); + batch_ = std::make_unique>(engine_params, cache_block_seq_len, quant_policy, this, medusa_num_heads_); initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy); @@ -102,6 +108,11 @@ LlamaV2::LlamaV2(size_t head_num, /// TODO: decouple Llama model and batch inference batch_->Start(); + + if (medusa_num_heads_ != 0) { + medusa_head_ = std::make_unique>( + hidden_units_, vocab_size_, medusa_num_heads_, stream_, cublas_wrapper_, allocator_, tensor_para_, false); + } } template @@ -498,6 +509,47 @@ void LlamaV2::forward(std::unordered_map* outputs, } } +template +void LlamaV2::medusaForward(int* topk_output_ids, const T* input_buf, const size_t batch_size) +{ + turbomind::DataType dtype = turbomind::getTensorType(); + + turbomind::TensorMap inputs{ + {"medusa_head_input", {turbomind::MEMORY_GPU, dtype, {batch_size, hidden_units_}, input_buf}}, + }; + + turbomind::TensorMap outputs{ + {"medusa_head_output", + {turbomind::MEMORY_GPU, dtype, {(size_t)medusa_num_heads_, batch_size, 1}, topk_output_ids}}, + }; + + medusa_head_->forward(&outputs, &inputs, weights_->get_medusa_weight()); +} + +template +void LlamaV2::batchDynamicDecode( + int* output_ids, const float* logits, const int* end_ids, curandState_t* curand_state, size_t batch_size) +{ + const int step = 0; + const int max_input_length = 0; + const int ite = 0; + std::unordered_map input_tensors{ + {"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, 1, vocab_size_}, logits}}, + {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"max_input_length", {MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"end_id", {MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}}, + {"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}}, + {"local_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &batch_size}}}; + + std::unordered_map output_tensors{ + {"output_ids", {MEMORY_GPU, TYPE_INT32, {1, batch_size, 1}, output_ids}}, + {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}}; + + TensorMap runtime_args; + dynamic_decode_layer_->setup(batch_size, 1, &runtime_args); + dynamic_decode_layer_->forward(&output_tensors, &input_tensors); +} + template class LlamaV2; template class LlamaV2; #ifdef ENABLE_BF16 diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index c770f24289..8e611c7ece 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -29,6 +29,7 @@ #include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/unified_decoder.h" +#include "src/turbomind/models/medusa_plugin/medusa_head.h" #include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/instance_comm.h" @@ -75,7 +76,9 @@ class LlamaV2 { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - cudaDeviceProp* cuda_device_prop); + cudaDeviceProp* cuda_device_prop, + int medusa_num_heads = 0, + int medusa_num_layers = 0); struct Control { AbstractInstanceComm* comm; @@ -145,6 +148,10 @@ class LlamaV2 { size_t token_ids_len, size_t batch_size); + void medusaForward(int* topk_output_ids, const T* input_buf, const size_t batch_size); + void batchDynamicDecode( + int* token_ids, const float* logits, const int* end_ids, curandState_t* curand_state, size_t batch_size); + private: friend class LlamaBatch; @@ 
-184,6 +191,11 @@ class LlamaV2 { std::shared_ptr shared_state_; ffi_api_lock_ctrl_t ffi_lock_; std::unique_ptr> batch_; + + int medusa_num_heads_ = 0; + int medusa_num_layers_ = 0; + + std::unique_ptr> medusa_head_; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 6e62eaf420..6a6646aa36 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -19,6 +19,7 @@ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptWeight.cc #include "src/turbomind/models/llama/LlamaWeight.h" +#include "src/turbomind/models/medusa_plugin/medusa_weight.h" namespace turbomind { @@ -33,7 +34,9 @@ LlamaWeight::LlamaWeight(size_t head_num, WeightType weight_type, int group_size, size_t tensor_para_size, - size_t tensor_para_rank): + size_t tensor_para_rank, + int medusa_num_heads, + int medusa_num_layers): hidden_units_(head_num * size_per_head), inter_size_(inter_size), vocab_size_(vocab_size), @@ -41,7 +44,8 @@ LlamaWeight::LlamaWeight(size_t head_num, num_layer_(num_layer), weight_type_(weight_type), tensor_para_size_(tensor_para_size), - tensor_para_rank_(tensor_para_rank) + tensor_para_rank_(tensor_para_rank), + medusa_enable_(medusa_num_heads != 0) { if (vocab_size_padded_ % tensor_para_size_ != 0) { vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_; @@ -61,6 +65,15 @@ LlamaWeight::LlamaWeight(size_t head_num, } mallocWeights(); + if (medusa_enable_) { + medusa_weight = std::make_unique>(medusa_num_heads, + medusa_num_layers, + hidden_units_, + vocab_size_, + weight_type, + tensor_para_size, + tensor_para_rank); + } } template @@ -90,7 +103,7 @@ template void LlamaWeight::loadModel(std::string dir_path) { FtCudaDataType model_file_type = FtCudaDataType::FP16; - if(weight_type_ == WeightType::kBF16){ + if (weight_type_ == WeightType::kBF16) { model_file_type = FtCudaDataType::BF16; } dir_path += '/'; @@ -110,6 +123,10 @@ void LlamaWeight::loadModel(std::string dir_path) for (unsigned layer = 0; layer < num_layer_; ++layer) { decoder_layer_weights[layer]->loadModel(dir_path + "layers." 
+ std::to_string(layer), model_file_type); } + + if (medusa_enable_) { + medusa_weight->load_model(dir_path, model_file_type); + } } template @@ -141,6 +158,12 @@ TensorMap LlamaWeight::getParams() return output; } +template +const MedusaWeight& LlamaWeight::get_medusa_weight() const +{ + return *medusa_weight.get(); +} + template struct LlamaWeight; template struct LlamaWeight; #ifdef ENABLE_BF16 diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index a896a87a09..b5d1f4da21 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -21,6 +21,7 @@ #pragma once #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" +#include "src/turbomind/models/medusa_plugin/medusa_weight.h" #include "src/turbomind/utils/memory_utils.h" namespace turbomind { @@ -38,7 +39,9 @@ struct LlamaWeight { WeightType weight_type, int group_size, size_t tensor_para_size, - size_t tensor_para_rank); + size_t tensor_para_rank, + int medusa_num_heads = 0, + int medusa_num_layers = 0); ~LlamaWeight(); @@ -54,6 +57,11 @@ struct LlamaWeight { const T* output_norm_weight{}; const T* post_decoder_embedding_kernel{}; + std::unique_ptr> medusa_weight; + bool medusa_enable_ = false; + + const MedusaWeight& get_medusa_weight() const; + private: void mallocWeights(); diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h index a6c341fa6d..e73a69d2e6 100644 --- a/src/turbomind/models/llama/SequenceManager.h +++ b/src/turbomind/models/llama/SequenceManager.h @@ -40,6 +40,8 @@ struct Sequence { explicit Sequence(uint64_t _id): id(_id) {} friend std::ostream& operator<<(std::ostream& os, const Sequence& seq); + + mutable size_t iter = 0; }; using Sequences = std::vector; diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu index 31f450f05f..4b935fc58c 100644 --- a/src/turbomind/models/llama/llama_kernels.cu +++ b/src/turbomind/models/llama/llama_kernels.cu @@ -908,7 +908,6 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cud template __global__ void medusaBatchedMatchKernel(const int* __restrict__ input_ids, const int* __restrict__ output_ids, - int* match_idx, int* match_length, const int path_num, int size) @@ -943,32 +942,29 @@ __global__ void medusaBatchedMatchKernel(const int* __restrict__ input_ids, if (tid == 0) { const int index = bid; - match_idx[index] = total.p; match_length[index] = total.u; } __syncthreads(); } -void invokeMedusaBatchedMatchKernel(const int* input_ids, - const int* output_ids, - int* max_match_idx, - int* max_match_length, - int batch_size, - int path_num, - int medusa_head_num, - cudaStream_t stream) +void invokeMedusaBatchMatch(const int* input_ids, + const int* output_ids, + int* max_match_length, + int batch_size, + int path_num, + int medusa_head_num, + cudaStream_t stream) { // inputs: // input_ids: [batch_size, path_num, 1 + medusa_head_num] // output_ids: [batch_size, path_num, 1 + medusa_head_num] // outputs: - // max_match_idx: [batch_size] // max_match_length: [batch_size] dim3 grid, block; grid.x = batch_size; block.x = 64; - medusaBatchedMatchKernel<64><<>>( - input_ids, output_ids, max_match_idx, max_match_length, path_num, medusa_head_num); + medusaBatchedMatchKernel<64> + <<>>(input_ids, output_ids, max_match_length, path_num, medusa_head_num); } } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_kernels.h 
b/src/turbomind/models/llama/llama_kernels.h index cd10939bcc..ba12e3ccc6 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -167,13 +167,12 @@ inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_ TM_LOG_ERROR("--------> rank = %d, step = %d, seq_len = %d <--------", tp_rank, step, h_seq_len); } -void invokeMedusaBatchedMatchKernel(const int* input_ids, - const int* output_ids, - int* max_match_idx, - int* max_match_length, - int batch_size, - int path_num, - int medusa_head_num, - cudaStream_t stream); +void invokeMedusaBatchMatch(const int* input_ids, + const int* output_ids, + int* max_match_length, + int batch_size, + int path_num, + int medusa_head_num, + cudaStream_t stream); } // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/CMakeLists.txt b/src/turbomind/models/medusa_plugin/CMakeLists.txt index 2ab5eaebcf..ff586d26be 100644 --- a/src/turbomind/models/medusa_plugin/CMakeLists.txt +++ b/src/turbomind/models/medusa_plugin/CMakeLists.txt @@ -9,4 +9,8 @@ add_library(Medusa STATIC res_block.cc medusa_head.cc) -set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_executable(medusa_head_example tests/medusa_head_example.cc) +target_link_libraries(medusa_head_example PUBLIC Llama Medusa -lpthread) +install(TARGETS medusa_head_example DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin) diff --git a/src/turbomind/models/medusa_plugin/medusa_head.cc b/src/turbomind/models/medusa_plugin/medusa_head.cc index 9febb22645..c9273efbb9 100644 --- a/src/turbomind/models/medusa_plugin/medusa_head.cc +++ b/src/turbomind/models/medusa_plugin/medusa_head.cc @@ -13,7 +13,7 @@ namespace turbomind { template MedusaHead::MedusaHead(size_t in_size, size_t vocab_size, - size_t medusa_num_heads, + int medusa_num_heads, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, @@ -96,29 +96,9 @@ template void MedusaHead::top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k) { size_t workspace_size_now = 0; - invokeBatchTopKOnly(nullptr, - workspace_size_now, - d_input_logits, - nullptr, - k, - nullptr, - vocab_size_, - nullptr, - stream_, - batch_size, - nullptr); + invokeBatchTopK(nullptr, workspace_size_now, d_input_logits, k, vocab_size_, stream_, batch_size); workspace_buf_ = (void*)allocator_->reMalloc(workspace_buf_, workspace_size_now, false); - invokeBatchTopKOnly(workspace_buf_, - workspace_size_now, - d_input_logits, - nullptr, - k, - nullptr, - vocab_size_, - nullptr, - stream_, - batch_size, - nullptr); + invokeBatchTopK(workspace_buf_, workspace_size_now, d_input_logits, k, vocab_size_, stream_, batch_size); int offset = (int)(ceil(batch_size * vocab_size_ / 4.)) * 4; int output_size = (int)(ceil(batch_size * k / 4.)) * 4; int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset); diff --git a/src/turbomind/models/medusa_plugin/medusa_head.h b/src/turbomind/models/medusa_plugin/medusa_head.h index 62c7f618e4..beeffc636b 100644 --- a/src/turbomind/models/medusa_plugin/medusa_head.h +++ b/src/turbomind/models/medusa_plugin/medusa_head.h @@ -17,8 +17,8 @@ template class MedusaHead { public: MedusaHead(size_t in_size, - size_t out_size, - size_t medusa_num_heads, + size_t vocab_size, + int medusa_num_heads, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, @@ -43,7 +43,7 @@ class MedusaHead { private: size_t in_size_; size_t 
vocab_size_; - size_t medusa_num_heads_; + int medusa_num_heads_; std::unique_ptr> resblock_; std::unique_ptr> linear_; diff --git a/src/turbomind/models/medusa_plugin/medusa_weight.cc b/src/turbomind/models/medusa_plugin/medusa_weight.cc index 04113fe29d..19dd2e26de 100644 --- a/src/turbomind/models/medusa_plugin/medusa_weight.cc +++ b/src/turbomind/models/medusa_plugin/medusa_weight.cc @@ -12,8 +12,8 @@ namespace turbomind { template -MedusaWeight::MedusaWeight(size_t medusa_num_heads, - size_t medusa_num_layers, +MedusaWeight::MedusaWeight(int medusa_num_heads, + int medusa_num_layers, size_t hidden_size, size_t vocab_size, WeightType weight_type, diff --git a/src/turbomind/models/medusa_plugin/medusa_weight.h b/src/turbomind/models/medusa_plugin/medusa_weight.h index c44fa0fa18..0fe6dbed0f 100644 --- a/src/turbomind/models/medusa_plugin/medusa_weight.h +++ b/src/turbomind/models/medusa_plugin/medusa_weight.h @@ -13,8 +13,8 @@ namespace turbomind { template class MedusaWeight { public: - MedusaWeight(size_t medusa_num_heads, - size_t medusa_num_layers, + MedusaWeight(int medusa_num_heads, + int medusa_num_layers, size_t hidden_size, size_t vocab_size, WeightType weight_type, @@ -38,8 +38,8 @@ class MedusaWeight { void load_bias(LlamaDenseWeight* weight, const std::string& path, FtCudaDataType model_file_type); private: - size_t medusa_num_heads_; - size_t medusa_num_layers_; + int medusa_num_heads_; + int medusa_num_layers_; size_t hidden_size_; size_t vocab_size_; WeightType weight_type_; diff --git a/src/turbomind/models/medusa_plugin/tests/medusa_head_example.cc b/src/turbomind/models/medusa_plugin/tests/medusa_head_example.cc new file mode 100644 index 0000000000..fb03df872a --- /dev/null +++ b/src/turbomind/models/medusa_plugin/tests/medusa_head_example.cc @@ -0,0 +1,281 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/models/medusa_plugin/medusa_head.h" +#include "src/turbomind/models/medusa_plugin/medusa_weight.h" +#include "src/turbomind/utils/cuda_utils.h" +#include +#include + +template +float T_to_float(T val) +{ + if (std::is_same::value) { + return __half2float((const half)val); + } + else if (std::is_same::value) { + return __bfloat162float((const __nv_bfloat16)val); + } +} + +template +std::pair get_type() +{ + turbomind::WeightType weight_type; + turbomind::FtCudaDataType model_file_type; + if (std::is_same::value) { + weight_type = turbomind::WeightType::kFP16; + model_file_type = turbomind::FtCudaDataType::FP16; + } + else if (std::is_same::value) { + weight_type = turbomind::WeightType::kBF16; + model_file_type = turbomind::FtCudaDataType::BF16; + } + return std::make_pair(weight_type, model_file_type); +} + +template +class MedusaHeadExample { +public: + MedusaHeadExample(size_t batch_size, + int medusa_num_heads, + size_t medusa_num_layers, + size_t hidden_size, + size_t vocab_size, + std::string dir_path, + turbomind::NcclParam tensor_para, + cudaStream_t stream, + turbomind::cublasMMWrapper* cublas_wrapper, + turbomind::IAllocator* allocator): + batch_size_(batch_size), + medusa_num_heads_(medusa_num_heads), + medusa_num_layers_(medusa_num_layers), + hidden_size_(hidden_size), + vocab_size_(vocab_size), + model_(hidden_size, vocab_size, medusa_num_heads, stream, cublas_wrapper, allocator, tensor_para, false), + input_buf_(nullptr), + allocator_(allocator), + rank_(tensor_para.rank_) + { + + auto type = get_type(); + auto weight_type = type.first; + auto model_file_type = type.second; + weights_ = std::make_unique>(medusa_num_heads, + medusa_num_layers, + hidden_size, + vocab_size, + weight_type, + tensor_para.world_size_, + tensor_para.rank_); + weights_->load_model(dir_path, model_file_type); + } + + ~MedusaHeadExample() + { + if (is_allocated) { + allocator_->free((void**)&topk_output_ids_); + allocator_->free((void**)&input_buf_); + } + } + + void forward(int seed = 7) + { + input_buf_ = nullptr; + + input_buf_ = (T*)allocator_->reMalloc(input_buf_, sizeof(T) * batch_size_ * hidden_size_, false); + + topk_output_ids_ = + (int*)allocator_->reMalloc(topk_output_ids_, sizeof(int) * medusa_num_heads_ * batch_size_, false, true); + + size_t total_size = batch_size_ * std::max(hidden_size_, vocab_size_); + buf_host_ = new T[total_size]; + + for (int i = 0; i < total_size; i++) { + buf_host_[i] = i % seed * 1.0; + } + + cudaMemcpy(input_buf_, buf_host_, sizeof(T) * batch_size_ * hidden_size_, cudaMemcpyHostToDevice); + + is_allocated = true; + + turbomind::DataType dtype = turbomind::getTensorType(); + turbomind::TensorMap inputs{ + {"medusa_head_input", {turbomind::MEMORY_GPU, dtype, {batch_size_, hidden_size_}, input_buf_}}, + }; + + // top 1 + turbomind::TensorMap outputs{ + {"medusa_head_output", + {turbomind::MEMORY_GPU, dtype, {medusa_num_heads_, batch_size_, 1}, topk_output_ids_}}, + }; + + model_.forward(&outputs, &inputs, *weights_.get()); + + int* topk_output_ids = outputs.at("medusa_head_output").getPtr(); + if (rank_ == 0) { + for (int i = 0; i < batch_size_ * medusa_num_heads_; i++) { + std::cout << "topk_output_ids[" << i << "]=" << topk_output_ids[i] << '\n'; + } + } + + delete[] buf_host_; + } + +private: + size_t batch_size_; + size_t medusa_num_layers_; + int medusa_num_heads_; + size_t hidden_size_; + size_t vocab_size_; + + T* input_buf_; + + T* buf_host_ = nullptr; + + int* topk_output_ids_; + + turbomind::IAllocator* allocator_; + + 
turbomind::MedusaHead model_; + std::unique_ptr> weights_; + + bool is_allocated = false; + int rank_ = -1; +}; + +template +void fire(int tp, + int batch_size = 2, + int seed = 7, + size_t medusa_num_heads = 5, + size_t medusa_num_layers = 1, + size_t hidden_size = 5120, + size_t vocab_size = 32000) +{ + std::string dtype; + if (std::is_same::value) { + dtype = "fp16"; + } + else if (std::is_same::value) { + dtype = "bf16"; + } + + std::string dir_path; + if (tp == 1) { + if (std::is_same::value) { + dir_path = "/workdir/medusa_output/fp16/tp1"; + } + else if (std::is_same::value) { + dir_path = "/workdir/medusa_output/bf16/tp1"; + } + } + else if (tp == 2) { + if (std::is_same::value) { + dir_path = "/workdir/medusa_output/fp16/tp2"; + } + else if (std::is_same::value) { + dir_path = "/workdir/medusa_output/bf16/tp2"; + } + } + + std::vector streams(tp); + std::vector>> allocators(tp); + std::vector cublas_handles(tp); + std::vector cublaslt_handles(tp); + std::vector cublas_algo_maps(tp); + std::vector cublas_wrapper_mutexs(tp); + std::vector> cublas_wrappers(tp); + std::vector threads; + std::vector>> models(tp); + std::vector tensor_params(tp); + + turbomind::NcclUid tensor_para_nccl_uid; + turbomind::ftNcclGetUniqueId(tensor_para_nccl_uid); + const auto group_id = turbomind::ftNcclNextGroupId(); + turbomind::ftNcclGroupStart(); + for (int rank = 0; rank < tp; rank++) { + turbomind::check_cuda_error(cudaSetDevice(rank)); + turbomind::ftNcclCommInitRank(tensor_params[rank], rank, tp, tensor_para_nccl_uid); + tensor_params[rank].group_id_ = group_id; + } + turbomind::ftNcclGroupEnd(); + + for (int rank = 0; rank < tp; rank++) { + std::cout << "rank=" << rank << " tp=" << tp << " dtype=" << dtype << " batch=" << batch_size + << " seed=" << seed << '\n'; + + turbomind::check_cuda_error(cudaSetDevice(rank)); + turbomind::check_cuda_error(cudaStreamCreate(&streams[rank])); + allocators[rank] = std::unique_ptr>( + new turbomind::Allocator(rank)); + allocators[rank]->setStream(streams[rank]); + cublasCreate(&cublas_handles[rank]); + cublasLtCreate(&cublaslt_handles[rank]); + cublasSetStream(cublas_handles[rank], streams[rank]); + cublas_algo_maps[rank] = turbomind::cublasAlgoMap(); + cublas_wrappers[rank] = + std::unique_ptr(new turbomind::cublasMMWrapper(cublas_handles[rank], + cublaslt_handles[rank], + streams[rank], + &cublas_algo_maps[rank], + &cublas_wrapper_mutexs[rank], + allocators[rank].get())); + if (std::is_same::value) { + cublas_wrappers[rank]->setFP16GemmConfig(); + } + else if (std::is_same::value) { + cublas_wrappers[rank]->setBF16GemmConfig(); + } + + models[rank] = std::unique_ptr>(new MedusaHeadExample(batch_size, + medusa_num_heads, + medusa_num_layers, + hidden_size, + vocab_size, + dir_path, + tensor_params[rank], + streams[rank], + cublas_wrappers[rank].get(), + allocators[rank].get())); + } + + auto threadForward = [streams, seed](int rank, MedusaHeadExample* model) { + turbomind::check_cuda_error(cudaSetDevice(rank)); + cudaDeviceSynchronize(); + model->forward(seed); + cudaDeviceSynchronize(); + turbomind::check_cuda_error(cudaStreamSynchronize(streams[rank])); + }; + + for (int rank = 0; rank < tp; rank++) { + threads.push_back(std::thread(threadForward, rank, models[rank].get())); + } + for (auto& t : threads) { + t.join(); + } +} + +int main(int argc, char** argv) +{ + std::vector seed_vec{7}; + std::vector batch_vec{1}; + std::vector type_vec{"bf16", "fp16"}; + std::vector tp_vec{1, 2}; + for (const int seed : seed_vec) { + for (const int batch : batch_vec) { + 
for (const std::string& type : type_vec) { + if (type == "bf16") { + for (const int tp : tp_vec) { + fire<__nv_bfloat16>(tp, batch, seed); + } + } + else if (type == "fp16") { + for (const int tp : tp_vec) { + fire(tp, batch, seed); + } + } + } + } + } + return 0; +} diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 6d2952e0f6..7e2bb5dba0 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -204,6 +204,9 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, engine_params_.extra_tokens_per_iter = reader.GetInteger("llama", "extra_tokens_per_iter", 0); engine_params_.max_prefill_iters = reader.GetInteger("llama", "max_prefill_iters", 1); + medusa_num_heads_ = reader.GetInteger("llama", "medusa_num_heads", 0); + medusa_num_layers_ = reader.GetInteger("llama", "medusa_num_layers", 0); + handleMissingParams(); shared_state_ = std::make_shared::SharedState>(); @@ -308,7 +311,9 @@ std::unique_ptr> LlamaTritonModel::createSh cublas_wrapper.get(), allocator.get(), false, // is_free_buffer_after_forward, - cuda_device_prop_ptr.get()); + cuda_device_prop_ptr.get(), + medusa_num_heads_, + medusa_num_layers_); return std::make_unique>( LlamaTritonSharedModelInstance{std::move(allocator), @@ -368,7 +373,9 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) weight_type_, group_size_, tensor_para_size_, - tensor_para_rank); + tensor_para_rank, + medusa_num_heads_, + medusa_num_layers_); // model inited with model_dir if (model_dir_ != "") { shared_weights_[device_id]->loadModel(model_dir_); @@ -406,7 +413,7 @@ std::string LlamaTritonModel::toString() << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << "\ngroup_size: " << group_size_ - << std::endl; + << "\nmedusa_num_heads: " << medusa_num_heads_ << "\nmedusa_num_layers: " << medusa_num_layers_ << std::endl; return ss.str(); } diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index ff086a9099..9ee406dd35 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -117,4 +117,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { std::string model_dir_; ffi_api_lock_ctrl_t ffi_lock_ = nullptr; + + int medusa_num_heads_ = 0; + int medusa_num_layers_ = 0; };
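Note on the simplified top-k entry point: `invokeBatchTopK` (declared in `sampling_topk_kernels.h` above) keeps the two-pass workspace convention used elsewhere in turbomind — a first call with a null workspace only reports the required workspace size, and a second call performs the batched top-k — which is how `MedusaHead::top_k` in this patch uses it. A minimal host-side caller sketch under that assumption; `run_batch_topk`, `d_logits`, and the allocator handling are illustrative and not part of the patch:

#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/utils/allocator.h"

// Hedged sketch mirroring the query-then-run pattern of MedusaHead::top_k.
inline void run_batch_topk(turbomind::IAllocator* allocator,
                           const float*           d_logits,  // [batch_size, vocab_size_padded] on device
                           int                    max_top_k,
                           int                    vocab_size_padded,
                           int                    batch_size,
                           cudaStream_t           stream)
{
    // Pass 1: null workspace -> only the required workspace size is computed.
    size_t workspace_size = 0;
    turbomind::invokeBatchTopK(nullptr, workspace_size, d_logits, max_top_k, vocab_size_padded, stream, batch_size);

    // Pass 2: run the batched top-k with a workspace of that size.
    void* workspace = nullptr;
    workspace       = allocator->reMalloc(workspace, workspace_size, false);
    turbomind::invokeBatchTopK(workspace, workspace_size, d_logits, max_top_k, vocab_size_padded, stream, batch_size);

    allocator->free((void**)&workspace);
}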