refine
zhyncs committed Mar 19, 2024
1 parent f132e71 commit a5f8c9a
Showing 7 changed files with 41 additions and 24 deletions.
40 changes: 22 additions & 18 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -1,6 +1,7 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/LlamaBatch.h"
#include "SequenceManager.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/macro.h"
@@ -1511,18 +1512,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
pf_offset = i;
}
if (medusa_enable_) {
-auto& medusa_state = medusa_state_vec_[i];
-medusa_state.len = seq.input_length;
-if (seq.iter == 0) {
-medusa_state.verified_len = seq.input_length;
-medusa_state.inited = false;
-medusa_state.index = new_index++;
-}
-else {
-medusa_state.verified_len = 0;
-medusa_state.inited = true;
-medusa_state.index = inited_index++;
-}
+MedusaInit(medusa_state_vec_, inited_index, new_index, i, seq);
}
}
if (pf_offset < 0) {
@@ -1594,12 +1584,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)

batched_copy.Submit(stream_);

-int dc_batch_size = p ? 0 : pf_offset;
-int pf_batch_size = mini_batch_size - dc_batch_size;
-if (medusa_enable_) {
-dc_batch_size = 0;
-pf_batch_size = mini_batch_size;
-}
+const int dc_batch_size = p ? 0 : pf_offset;
+const int pf_batch_size = mini_batch_size - dc_batch_size;

if (rank_ == 0) {
if (pf_batch_size) {
@@ -1734,6 +1720,24 @@ std::ostream& operator<<(std::ostream& os, const MedusaState& medusa_state)
return os;
}

+template<typename T>
+void LlamaBatch<T>::MedusaInit(
+std::vector<MedusaState>& medusa_state_vec, int& inited_index, int& new_index, const int index, const Sequence& seq)
+{
+auto& medusa_state = medusa_state_vec[index];
+medusa_state.len = seq.input_length;
+if (seq.iter == 0) {
+medusa_state.verified_len = seq.input_length;
+medusa_state.inited = false;
+medusa_state.index = new_index++;
+}
+else {
+medusa_state.verified_len = 0;
+medusa_state.inited = true;
+medusa_state.index = inited_index++;
+}
+}
+
template<typename T>
void LlamaBatch<T>::MedusaCopy(const int mini_batch_size, const int first)
{
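For orientation, here is a minimal sketch of the call site that the -1511 hunk above refactors; the loop bounds and the `sequences` container are illustrative stand-ins for the surrounding Forward() code, not exact member names:

// Editorial sketch, not part of the commit: both counters start at zero for
// the pass, and MedusaInit assigns each sequence either a "new" slot
// (first iteration) or an "already initialized" slot.
int inited_index = 0;
int new_index    = 0;
for (int i = 0; i < batch_size; ++i) {    // batch_size: assumed name
    const Sequence& seq = *sequences[i];  // sequences: assumed container
    if (medusa_enable_) {
        MedusaInit(medusa_state_vec_, inited_index, new_index, i, seq);
    }
}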
5 changes: 5 additions & 0 deletions src/turbomind/models/llama/LlamaBatch.h
@@ -194,6 +194,11 @@ class LlamaBatch {
IndexedCopyImpl(nullptr, nullptr, count, cpys...);
}

+void MedusaInit(std::vector<MedusaState>& medusa_state_vec,
+int& inited_index,
+int& new_index,
+const int index,
+const Sequence& seq);
void MedusaCopy(const int mini_batch_size, const int first);
void MedusaVerify(const int inited_index);
void MedusaGenerate(const int inited_index,
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaV2.cc
@@ -520,7 +520,7 @@ void LlamaV2<T>::medusaForward(int* topk_output_ids, const T* input_buf, const s

turbomind::TensorMap outputs{
{"medusa_head_output",
-{turbomind::MEMORY_GPU, dtype, {(size_t)medusa_num_heads_, batch_size, 1}, topk_output_ids}},
+{turbomind::MEMORY_GPU, dtype, {batch_size, (size_t)medusa_num_heads_, 1}, topk_output_ids}},
};

medusa_head_->forward(&outputs, &inputs, weights_->get_medusa_weight());
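The only change here flips the layout of medusa_head_output from head-major to batch-major. As an editorial illustration of what that means for indexing (not part of the commit):

// New shape {batch_size, medusa_num_heads, 1}: head k's token for sequence b is
//   topk_output_ids[b * medusa_num_heads + k];
// old shape {medusa_num_heads, batch_size, 1} grouped by head instead:
//   topk_output_ids[k * batch_size + b];
// i.e. the top-1 ids are now contiguous per sequence rather than per head.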
9 changes: 8 additions & 1 deletion src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -3,6 +3,7 @@
// Zhiwei Bao <[email protected]>

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/utils/Tensor.h"
@@ -78,6 +79,8 @@ void MedusaHead<T>::allocate_buffer(size_t batch_size)
(T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
medusa_head_logits_buf_ = (T*)allocator_->reMalloc(
medusa_head_logits_buf_, medusa_num_heads_ * sizeof(T) * batch_size * vocab_size_, false);
+topk_output_ids_t_ =
+(int*)allocator_->reMalloc(topk_output_ids_t_, sizeof(int) * batch_size * medusa_num_heads_, false);
is_allocated_buffer_ = true;
}

@@ -88,6 +91,7 @@ void MedusaHead<T>::free_buffer()
allocator_->free((void**)&resblock_buf_);
allocator_->free((void**)&workspace_buf_);
allocator_->free((void**)&medusa_head_logits_buf_);
+allocator_->free((void**)&topk_output_ids_t_);
is_allocated_buffer_ = false;
}
}
@@ -102,7 +106,10 @@ void MedusaHead<T>::top_k(int* h_topk_output_ids, const T* d_input_logits, const
int offset = (int)(ceil(batch_size * vocab_size_ / 4.)) * 4;
int output_size = (int)(ceil(batch_size * k / 4.)) * 4;
int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset);
-cudaMemcpy(h_topk_output_ids, topk_output_ids, sizeof(int) * output_size, cudaMemcpyDeviceToHost);
+invokeTransposeAxis01(
+topk_output_ids_t_, topk_output_ids, medusa_num_heads_, batch_size / medusa_num_heads_, 1, stream_);
+cudaMemcpyAsync(h_topk_output_ids, topk_output_ids_t_, sizeof(int) * output_size, cudaMemcpyDeviceToHost, stream_);
+cudaStreamSynchronize(stream_);
}

template class MedusaHead<float>;
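To make the new data flow concrete, here is a CPU-side sketch of the transpose the commit inserts before the device-to-host copy; invokeTransposeAxis01 itself is the existing turbomind kernel, and the loop below only mirrors its effect for dim2 == 1:

// Editorial sketch (not part of the commit): the added kernel call reorders
// the top-1 ids from head-major [medusa_num_heads][batch] to batch-major
// [batch][medusa_num_heads], matching the new output tensor shape.
void transpose_axis01_cpu(int* dst, const int* src, int dim0, int dim1)
{
    for (int i = 0; i < dim0; ++i) {      // dim0: medusa_num_heads_
        for (int j = 0; j < dim1; ++j) {  // dim1: per-head batch count
            dst[j * dim0 + i] = src[i * dim1 + j];
        }
    }
}

The host copy then runs on stream_ via cudaMemcpyAsync and is fenced by cudaStreamSynchronize, so it is ordered after the transpose on that stream before h_topk_output_ids is read.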
1 change: 1 addition & 0 deletions src/turbomind/models/medusa_plugin/medusa_head.h
@@ -51,6 +51,7 @@ class MedusaHead {
T* resblock_buf_;
void* workspace_buf_;
T* medusa_head_logits_buf_;
+int* topk_output_ids_t_;

cudaStream_t stream_;
cublasMMWrapper* cublas_wrapper_;
4 changes: 2 additions & 2 deletions src/turbomind/models/medusa_plugin/medusa_weight.cc
@@ -122,13 +122,13 @@ void MedusaWeight<T>::load_bias(LlamaDenseWeight<T>* weight, const std::string&
template<typename T>
void MedusaWeight<T>::load_model(const std::string& dir_path, FtCudaDataType model_file_type)
{
-auto ends_with = [](std::string& text, const std::string& suffix) noexcept {
+auto ends_with = [](const std::string& text, const std::string& suffix) noexcept {
return suffix.empty()
|| (text.size() >= suffix.size()
&& std::memcmp(text.data() + (text.size() - suffix.size()), suffix.data(), suffix.size()) == 0);
};
std::string weight_path = dir_path;
-if (!ends_with(weight_path, "/")) {
+if (!ends_with(dir_path, "/")) {
weight_path.append("/");
}
std::string prefix = "medusa.";
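The point of this change is const-correctness: load_model receives dir_path as a const reference, so the old lambda, whose first parameter was a non-const std::string&, could only be called with the mutable copy weight_path. A minimal self-contained sketch of the difference (function and variable names here are illustrative, and the suffix check is simplified to std::string::compare):

#include <string>

// Editorial sketch, not part of the commit.
void load_model_sketch(const std::string& dir_path)
{
    auto ends_with_old = [](std::string& text, const std::string& suffix) noexcept {
        return text.size() >= suffix.size()
               && text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
    };
    auto ends_with_new = [](const std::string& text, const std::string& suffix) noexcept {
        return text.size() >= suffix.size()
               && text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
    };
    std::string weight_path = dir_path;
    // ends_with_old(dir_path, "/");    // would not compile: const argument,
    //                                  // hence the old call used weight_path
    if (!ends_with_new(dir_path, "/")) {  // fine after this commit
        weight_path.append("/");
    }
    (void)ends_with_old;  // silence the unused-variable warning in this sketch
}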
4 changes: 2 additions & 2 deletions (seventh changed file; path not captured, hunks reference class MedusaHeadExample)
@@ -106,7 +106,7 @@ class MedusaHeadExample {
// top 1
turbomind::TensorMap outputs{
{"medusa_head_output",
-{turbomind::MEMORY_GPU, dtype, {medusa_num_heads_, batch_size_, 1}, topk_output_ids_}},
+{turbomind::MEMORY_GPU, dtype, {batch_size_, medusa_num_heads_, 1}, topk_output_ids_}},
};

model_.forward(&outputs, &inputs, *weights_.get());
@@ -258,7 +258,7 @@ void fire(int tp,
int main(int argc, char** argv)
{
std::vector<int> seed_vec{7};
-std::vector<int> batch_vec{1};
+std::vector<int> batch_vec{2};
std::vector<std::string> type_vec{"bf16", "fp16"};
std::vector<int> tp_vec{1, 2};
for (const int seed : seed_vec) {
