feat: porting medusa head and resblock with tp support
Showing 12 changed files with 699 additions and 0 deletions. Four of the changed files are reproduced below: two CMake lists, medusa_head.cc, and medusa_head.h.
First, the parent CMakeLists.txt registers the new medusa_plugin subdirectory (@@ -13,3 +13,4 @@):

```diff
 # limitations under the License.

 add_subdirectory(llama)
+add_subdirectory(medusa_plugin)
```
New file, 12 lines — the CMakeLists.txt for the medusa_plugin directory:
```cmake
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.8)

find_package(CUDAToolkit REQUIRED)

add_library(Medusa STATIC
            medusa_weight.cc
            res_block.cc
            medusa_head.cc)

set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON)
```
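res_block.cc and medusa_weight.cc are listed above but their hunks are not reproduced on this page. For orientation, a Medusa head first sends the hidden states through a residual block, out = x + SiLU(W·x), before the per-head projection. The sketch below is a plain-CPU illustration of that math only; all names are hypothetical, and the actual res_block.cc presumably runs a cuBLAS GEMM plus an activation kernel on the stream:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference math for a Medusa residual block: out = x + silu(W * x),
// where silu(v) = v / (1 + exp(-v)). Hypothetical sketch, not the
// TurboMind implementation.
std::vector<float> res_block_ref(const std::vector<float>& x,
                                 const std::vector<float>& W,  // [d, d], row-major
                                 std::size_t               d)
{
    std::vector<float> out(d);
    for (std::size_t i = 0; i < d; ++i) {
        float acc = 0.f;
        for (std::size_t j = 0; j < d; ++j) {
            acc += W[i * d + j] * x[j];  // one row of the GEMV
        }
        const float silu = acc / (1.f + std::exp(-acc));  // SiLU activation
        out[i] = x[i] + silu;                             // residual add
    }
    return out;
}
```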
New file, 93 lines — src/turbomind/models/medusa_plugin/medusa_head.cc:
```cpp
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cublasMMWrapper.h"

namespace turbomind {

template<typename T>
MedusaHead<T>::MedusaHead(size_t           in_size,
                          size_t           out_size,
                          size_t           medusa_num_heads,
                          cudaStream_t     stream,
                          cublasMMWrapper* cublas_wrapper,
                          IAllocator*      allocator,
                          NcclParam        tensor_para,
                          bool             is_free_buffer_after_forward):
    in_size_(in_size),
    out_size_(out_size),
    medusa_num_heads_(medusa_num_heads),
    stream_(stream),
    cublas_wrapper_(cublas_wrapper),
    allocator_(allocator),
    tensor_para_(tensor_para),
    is_free_buffer_after_forward_(is_free_buffer_after_forward)
{
    resblock_ = std::make_unique<ResBlock<T>>(in_size_, stream_, cublas_wrapper_, tensor_para_);
    linear_   = std::make_unique<LlamaLinear<T>>(cublas_wrapper_, stream_);
}

template<typename T>
void MedusaHead<T>::forward(TensorMap*             output_tensors,
                            const TensorMap*       input_tensors,
                            const MedusaWeight<T>& medusa_weight)
{
    const size_t     batch_size             = input_tensors->at("medusa_head_input").shape[0];
    const T*         hidden_states          = input_tensors->at("medusa_head_input").getPtr<T>();
    std::vector<T*>* medusa_head_logits_vec = output_tensors->at("medusa_head_output").getPtr<std::vector<T*>>();
    // TODO parallelize this loop
    for (int i = 0; i < medusa_num_heads_; i++) {
        T* medusa_head_logits = (*medusa_head_logits_vec)[i];
        forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
    }
}

template<typename T>
void MedusaHead<T>::forward(T*                     medusa_head_output,
                            const T*               medusa_head_input,
                            size_t                 batch_size,
                            const MedusaWeight<T>& medusa_weight,
                            int                    head_id)
{
    allocate_buffer(batch_size);
    // TODO support multi medusa_num_layers
    resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
    linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);

    // under tensor parallelism each rank holds a partial sum of the logits; reduce across ranks
    if (tensor_para_.world_size_ > 1) {
        NcclGuard nccl_guard(tensor_para_, stream_);
        ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * out_size_, tensor_para_, stream_);
        sync_check_cuda_error();
    }

    free_buffer();
}

template<typename T>
void MedusaHead<T>::allocate_buffer(size_t batch_size)
{
    // each rank only buffers its shard of the resblock output, hence the / world_size_
    resblock_buf_ =
        (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
    is_allocated_buffer_ = true;
}

template<typename T>
void MedusaHead<T>::free_buffer()
{
    if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
        allocator_->free((void**)&resblock_buf_);
        is_allocated_buffer_ = false;
    }
}

template class MedusaHead<float>;
template class MedusaHead<half>;
#ifdef ENABLE_BF16
template class MedusaHead<__nv_bfloat16>;
#endif

}  // namespace turbomind
```
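A note on the TensorMap overload above: "medusa_head_output" is expected to wrap a host-side std::vector<T*> holding one preallocated device logits buffer per head. A hypothetical caller might look like the following; the Tensor construction details are assumptions for illustration, not code from this commit:

```cpp
// Hypothetical driver (illustration only): feeding MedusaHead<half> through
// the TensorMap interface. Assumes each logits[i] is a preallocated device
// buffer of batch_size * vocab_size elements.
std::vector<half*> logits(medusa_num_heads);
// ... allocate each logits[i] on the GPU ...

TensorMap inputs({{"medusa_head_input",
                   Tensor{MEMORY_GPU, TYPE_FP16, {batch_size, hidden_units}, hidden_states}}});
TensorMap outputs({{"medusa_head_output",
                    Tensor{MEMORY_CPU, TYPE_VOID, {1}, &logits}}});  // getPtr<std::vector<half*>>() reads this back

medusa_head.forward(&outputs, &inputs, medusa_weight);
```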
New file, 61 lines — src/turbomind/models/medusa_plugin/medusa_head.h:
```cpp
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#pragma once

#include "src/turbomind/models/llama/LlamaLinear.h"  // LlamaLinear<T> is used as a complete type below
#include "src/turbomind/models/medusa_plugin/medusa_weight.h"
#include "src/turbomind/models/medusa_plugin/res_block.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h>
#include <memory>

namespace turbomind {

template<typename T>
class MedusaHead {
public:
    MedusaHead(size_t           in_size,
               size_t           out_size,
               size_t           medusa_num_heads,
               cudaStream_t     stream,
               cublasMMWrapper* cublas_wrapper,
               IAllocator*      allocator,
               NcclParam        tensor_para,
               bool             is_free_buffer_after_forward = false);
    ~MedusaHead() = default;
    MedusaHead(const MedusaHead&) = delete;
    MedusaHead& operator=(const MedusaHead&) = delete;

    void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const MedusaWeight<T>& medusa_weight);
    void forward(T*                     medusa_head_output,
                 const T*               medusa_head_input,
                 size_t                 batch_size,
                 const MedusaWeight<T>& medusa_weight,
                 int                    head_id);

private:
    void allocate_buffer(size_t batch_size);
    void free_buffer();

private:
    size_t in_size_;
    size_t out_size_;
    size_t medusa_num_heads_;

    std::unique_ptr<ResBlock<T>>    resblock_;
    std::unique_ptr<LlamaLinear<T>> linear_;

    T* resblock_buf_ = nullptr;  // scratch for the resblock output; reMalloc inspects the old value, so start null

    cudaStream_t     stream_;
    cublasMMWrapper* cublas_wrapper_;
    IAllocator*      allocator_;

    NcclParam tensor_para_;

    bool is_allocated_buffer_          = false;
    bool is_free_buffer_after_forward_ = false;
};
}  // namespace turbomind
```
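On the "tp support" in the commit title: allocate_buffer sizes resblock_buf_ at batch_size * in_size_ / world_size_, so each rank holds only its shard of the resblock output, and the head GEMM then produces a partial logits sum that ftNcclAllReduceSum completes. With in_size = 4096 and world_size = 2, for example, each rank buffers batch_size × 2048 elements. This matches the usual column-parallel-then-row-parallel pairing; below is a minimal CPU sketch of that partial-sum-plus-reduce pattern with simulated ranks and hypothetical names, not the TurboMind code path:

```cpp
#include <cstddef>
#include <vector>

// Row-parallel linear + all-reduce, simulated on the CPU. Each "rank" r owns
// a shard of the input features and the matching weight columns; its GEMV
// yields a partial sum over the full output, and the += on logits[o] plays
// the role of ftNcclAllReduceSum.
std::vector<float> head_linear_tp(const std::vector<std::vector<float>>& in_shards,  // per rank: [d / ws]
                                  const std::vector<std::vector<float>>& w_shards,   // per rank: [out, d / ws]
                                  std::size_t                            out_size)
{
    std::vector<float> logits(out_size, 0.f);  // acts as the all-reduce accumulator
    for (std::size_t r = 0; r < in_shards.size(); ++r) {
        const std::size_t shard = in_shards[r].size();
        for (std::size_t o = 0; o < out_size; ++o) {
            float partial = 0.f;
            for (std::size_t k = 0; k < shard; ++k) {
                partial += w_shards[r][o * shard + k] * in_shards[r][k];
            }
            logits[o] += partial;  // sum of per-rank partials == full GEMV result
        }
    }
    return logits;
}
```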