feat: porting medusa head and resblock with tp support
zhyncs committed Mar 1, 2024
1 parent f81404a commit 11f8043
Showing 12 changed files with 699 additions and 0 deletions.
78 changes: 78 additions & 0 deletions src/turbomind/kernels/activation_kernels.cu
@@ -329,4 +329,82 @@ INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half);
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

template<template<typename T> class Activation, typename T, typename BT>
__global__ void fused_bias_residual_activation(
    T* out, const BT* __restrict bias, const T* __restrict residual, int m, int n, int tp_num, int tp_offset)
{
    const bool with_bias     = bias != nullptr;
    const bool with_residual = residual != nullptr;

    for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) {
        T val = out[id];

        if (with_bias) {
            T bias_val = static_cast<T>(bias[id % n]);
            val        = add(val, bias_val);
        }

        val = cuda_cast<T>(Activation<T>::apply(val));

        if (with_residual) {
            // `out` holds this rank's n-column shard; `residual` holds the full row
            // of n * tp_num columns, so rescale the row stride and shift by this
            // rank's column offset.
            T residual_val = static_cast<T>(residual[id % n + (id - id % n) * tp_num + tp_offset]);
            val            = add(val, residual_val);
        }

        out[id] = val;
    }
}

template<template<typename T> class Activation, typename T, typename BT>
void invokeFusedBiasResidualActivation(T*           out,
                                       const BT*    bias,
                                       const T*     residual,
                                       const int    m,
                                       const int    n,
                                       cudaStream_t stream,
                                       const int    tp_num,
                                       const int    tp_offset)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    // Operate on packed vectors (e.g. half2) so each thread processes
    // packed_elems scalars per iteration.
    using PT                   = typename packed_type<T>::type;
    constexpr int packed_elems = num_elems<PT>::value;
    using PBT                  = typename packed_as<BT, packed_elems>::type;

    dim3 block, grid;
    if (n / 4 / packed_elems <= 1024) {
        block.x = n / 4 / packed_elems;
        grid.x  = m;
    }
    else {
        block.x = 1024;
        grid.x  = ceil(m * n / 1024.);
    }
    fused_bias_residual_activation<Activation><<<grid, block, 0, stream>>>(reinterpret_cast<PT*>(out),
                                                                           reinterpret_cast<const PBT*>(bias),
                                                                           reinterpret_cast<const PT*>(residual),
                                                                           m,
                                                                           n / packed_elems,
                                                                           tp_num,
                                                                           tp_offset / packed_elems);
    sync_check_cuda_error();
}

#define INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(Activation, T, BT)                                                 \
    template void invokeFusedBiasResidualActivation<Activation, T, BT>(T * out,                                       \
                                                                       const BT*    bias,                             \
                                                                       const T*     residual,                         \
                                                                       const int    m,                                \
                                                                       const int    n,                                \
                                                                       cudaStream_t stream,                           \
                                                                       const int    tp_num,                           \
                                                                       const int    tp_offset);

INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, float, float);
INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

} // namespace turbomind
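
As a usage illustration (not part of the commit), here is a minimal host-side sketch of calling the new launcher on one tensor-parallel rank; the sizes, pointers, and rank offset below are hypothetical:

#include "src/turbomind/kernels/activation_kernels.h"
#include <cuda_fp16.h>

void example_silu_launch(half* out, const half* bias, const half* residual, cudaStream_t stream)
{
    const int m         = 8;     // rows (tokens); illustrative only
    const int n         = 2048;  // this rank's shard of the hidden dimension
    const int tp_num    = 2;     // tensor-parallel world size
    const int tp_offset = 0;     // rank 0's column offset into the full-width residual

    // `out` is [m, n] (sharded); `residual` is [m, n * tp_num] (full width).
    turbomind::invokeFusedBiasResidualActivation<turbomind::SiluActivation, half, half>(
        out, bias, residual, m, n, stream, tp_num, tp_offset);
}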
10 changes: 10 additions & 0 deletions src/turbomind/kernels/activation_kernels.h
@@ -107,4 +107,14 @@ void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream);
template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream);

template<template<typename T> class Activation, typename T, typename BT>
void invokeFusedBiasResidualActivation(T*           out,
                                       const BT*    bias,
                                       const T*     residual,
                                       const int    m,
                                       const int    n,
                                       cudaStream_t stream,
                                       const int    tp_num,
                                       const int    tp_offset);

} // namespace turbomind
1 change: 1 addition & 0 deletions src/turbomind/models/CMakeLists.txt
@@ -13,3 +13,4 @@
# limitations under the License.

add_subdirectory(llama)
add_subdirectory(medusa_plugin)
12 changes: 12 additions & 0 deletions src/turbomind/models/medusa_plugin/CMakeLists.txt
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.8)

find_package(CUDAToolkit REQUIRED)

add_library(Medusa STATIC
    medusa_weight.cc
    res_block.cc
    medusa_head.cc)

set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON)
93 changes: 93 additions & 0 deletions src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -0,0 +1,93 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cublasMMWrapper.h"

namespace turbomind {

template<typename T>
MedusaHead<T>::MedusaHead(size_t           in_size,
                          size_t           out_size,
                          size_t           medusa_num_heads,
                          cudaStream_t     stream,
                          cublasMMWrapper* cublas_wrapper,
                          IAllocator*      allocator,
                          NcclParam        tensor_para,
                          bool             is_free_buffer_after_forward):
    in_size_(in_size),
    out_size_(out_size),
    medusa_num_heads_(medusa_num_heads),
    stream_(stream),
    cublas_wrapper_(cublas_wrapper),
    allocator_(allocator),
    tensor_para_(tensor_para),
    is_free_buffer_after_forward_(is_free_buffer_after_forward)
{
    resblock_ = std::make_unique<ResBlock<T>>(in_size_, stream_, cublas_wrapper_, tensor_para_);
    linear_   = std::make_unique<LlamaLinear<T>>(cublas_wrapper_, stream_);
}

template<typename T>
void MedusaHead<T>::forward(TensorMap*             output_tensors,
                            const TensorMap*       input_tensors,
                            const MedusaWeight<T>& medusa_weight)
{
    const size_t     batch_size    = input_tensors->at("medusa_head_input").shape[0];
    const T*         hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
    std::vector<T*>* medusa_head_logits_vec = output_tensors->at("medusa_head_output").getPtr<std::vector<T*>>();
    // TODO parallelize this loop
    for (int i = 0; i < medusa_num_heads_; i++) {
        T* medusa_head_logits = (*medusa_head_logits_vec)[i];
        forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
    }
}

template<typename T>
void MedusaHead<T>::forward(T*                     medusa_head_output,
                            const T*               medusa_head_input,
                            size_t                 batch_size,
                            const MedusaWeight<T>& medusa_weight,
                            int                    head_id)
{
    allocate_buffer(batch_size);
    // TODO support multi medusa_num_layers
    resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
    linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);

    if (tensor_para_.world_size_ > 1) {
        // Each rank produced a partial logits tensor; sum the partials across ranks.
        NcclGuard nccl_guard(tensor_para_, stream_);
        ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * out_size_, tensor_para_, stream_);
        sync_check_cuda_error();
    }

    free_buffer();
}

template<typename T>
void MedusaHead<T>::allocate_buffer(size_t batch_size)
{
    resblock_buf_ =
        (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
    is_allocated_buffer_ = true;
}

template<typename T>
void MedusaHead<T>::free_buffer()
{
    if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
        allocator_->free((void**)&resblock_buf_);
        is_allocated_buffer_ = false;
    }
}

template class MedusaHead<float>;
template class MedusaHead<half>;
#ifdef ENABLE_BF16
template class MedusaHead<__nv_bfloat16>;
#endif

} // namespace turbomind
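
The buffer arithmetic above is worth spelling out. A small sketch of the per-rank shapes assumed by allocate_buffer and the allreduce (the world size, element size, and helper below are illustrative, not part of the commit):

// Per-rank shapes implied by MedusaHead<T>::forward (illustrative):
//   medusa_head_input  : [batch_size, in_size]               replicated on every rank
//   resblock_buf_      : [batch_size, in_size / world_size]  this rank's column shard
//   medusa_head_output : [batch_size, out_size]              partial sums per rank,
//                        completed by ftNcclAllReduceSum
size_t resblock_buf_bytes(size_t batch_size, size_t in_size, size_t world_size, size_t elem_size)
{
    // Mirrors allocate_buffer: each rank only materializes its shard of the
    // ResBlock output, so scratch memory shrinks by the tensor-parallel degree.
    return elem_size * batch_size * in_size / world_size;
}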
61 changes: 61 additions & 0 deletions src/turbomind/models/medusa_plugin/medusa_head.h
@@ -0,0 +1,61 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#pragma once

#include "src/turbomind/models/medusa_plugin/medusa_weight.h"
#include "src/turbomind/models/medusa_plugin/res_block.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h>
#include <memory>

namespace turbomind {

template<typename T>
class MedusaHead {
public:
    MedusaHead(size_t           in_size,
               size_t           out_size,
               size_t           medusa_num_heads,
               cudaStream_t     stream,
               cublasMMWrapper* cublas_wrapper,
               IAllocator*      allocator,
               NcclParam        tensor_para,
               bool             is_free_buffer_after_forward = false);
    ~MedusaHead() = default;
    MedusaHead(const MedusaHead&) = delete;
    MedusaHead& operator=(const MedusaHead&) = delete;

    void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const MedusaWeight<T>& medusa_weight);
    void forward(T*                     medusa_head_output,
                 const T*               medusa_head_input,
                 size_t                 batch_size,
                 const MedusaWeight<T>& medusa_weight,
                 int                    head_id);

private:
    void allocate_buffer(size_t batch_size);
    void free_buffer();

private:
    size_t in_size_;
    size_t out_size_;
    size_t medusa_num_heads_;

    std::unique_ptr<ResBlock<T>>    resblock_;
    std::unique_ptr<LlamaLinear<T>> linear_;

    // Initialized to nullptr so the first reMalloc in allocate_buffer does not
    // see an indeterminate pointer.
    T* resblock_buf_ = nullptr;

    cudaStream_t     stream_;
    cublasMMWrapper* cublas_wrapper_;
    IAllocator*      allocator_;

    NcclParam tensor_para_;

    bool is_allocated_buffer_          = false;
    bool is_free_buffer_after_forward_ = false;
};
} // namespace turbomind
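
For context, a hedged sketch of driving a single head through the public overload (the buffers, sizes, and weight object are placeholders, not part of this commit):

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include <cuda_fp16.h>

void run_one_head(turbomind::MedusaHead<half>&         head,
                  const turbomind::MedusaWeight<half>& weight,
                  half*                                logits,        // [batch_size, out_size]
                  const half*                          hidden_states, // [batch_size, in_size]
                  size_t                               batch_size)
{
    // Head 0: ResBlock -> lm head, with an NCCL allreduce when world_size_ > 1.
    head.forward(logits, hidden_states, batch_size, weight, /*head_id=*/0);
}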