[WIP] support Medusa #1231

Closed · wants to merge 2 commits
78 changes: 78 additions & 0 deletions src/turbomind/kernels/activation_kernels.cu
@@ -329,4 +329,82 @@ INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half);
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

template<template<typename T> class Activation, typename T, typename BT>
__global__ void fused_bias_residual_activation(
T* out, const BT* __restrict bias, const T* __restrict residual, int m, int n, int tp_num, int tp_offset)
{
const bool with_bias = bias != nullptr;
const bool with_residual = residual != nullptr;

for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) {
T val;

val = out[id];

if (with_bias) {
T bias_val = static_cast<T>(bias[id % n]);
val = add(val, bias_val);
}

val = cuda_cast<T>(Activation<T>::apply(val));

if (with_residual) {
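// The residual holds the full hidden dimension (n * tp_num) per row: id % n is the column
// inside this rank's slice and (id - id % n) * tp_num is the row start in the full-width
// tensor, so adding tp_offset selects this rank's slice of the residual.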
T residual_val = static_cast<T>(residual[id % n + (id - id % n) * tp_num + tp_offset]);
val = add(val, residual_val);
}

out[id] = val;
}
}

template<template<typename T> class Activation, typename T, typename BT>
void invokeFusedBiasResidualActivation(T* out,
const BT* bias,
const T* residual,
const int m,
const int n,
cudaStream_t stream,
const int tp_num,
const int tp_offset)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
using PT = typename packed_type<T>::type;
constexpr int packed_elems = num_elems<PT>::value;
using PBT = typename packed_as<BT, packed_elems>::type;

dim3 block, grid;
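// Launch config: when a quarter row of packed elements fits in a block, use one block per
// row (each thread then covers ~4 packed elements via the grid-stride loop); otherwise
// fall back to 1024-thread blocks spanning the whole m x n range.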
if (n / 4 / packed_elems <= 1024) {
block.x = n / 4 / packed_elems;
grid.x = m;
}
else {
block.x = 1024;
grid.x = ceil(m * n / 1024.);
}
fused_bias_residual_activation<Activation><<<grid, block, 0, stream>>>(reinterpret_cast<PT*>(out),
reinterpret_cast<const PBT*>(bias),
reinterpret_cast<const PT*>(residual),
m,
n / packed_elems,
tp_num,
tp_offset / packed_elems);
sync_check_cuda_error();
}

#define INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(Activation, T, BT) \
template void invokeFusedBiasResidualActivation<Activation, T, BT>(T * out, \
const BT* bias, \
const T* residual, \
const int m, \
const int n, \
cudaStream_t stream, \
const int tp_num, \
const int tp_offset);

INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, float, float);
INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

} // namespace turbomind
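For context, a minimal host-side sketch of how the new launcher could be driven follows. It is not part of this diff: the wrapper name and buffer names are illustrative, and it assumes the per-rank activation tensor is [m, n], that the residual is stored with the full hidden dimension (n * tp_num) per row as the kernel's indexing suggests, and that SiluActivation is visible to the caller.

#include "src/turbomind/kernels/activation_kernels.h"

// Hypothetical wrapper (not part of this PR): applies SiLU to `ffn_out` in place, with an
// optional per-rank bias and a residual read from this rank's slice of a full-width tensor.
template<typename T>
void applySiluBiasResidual(T*           ffn_out,   // [m, n] per-rank tensor, updated in place
                           const T*     bias,      // [n] per-rank bias, or nullptr
                           const T*     residual,  // [m, n * tp_size] full-width residual, or nullptr
                           int          m,
                           int          n,
                           int          tp_rank,
                           int          tp_size,
                           cudaStream_t stream)
{
    // tp_offset is the element offset of this rank's slice inside the full hidden dimension.
    turbomind::invokeFusedBiasResidualActivation<turbomind::SiluActivation>(
        ffn_out, bias, residual, m, n, stream, tp_size, tp_rank * n);
}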
10 changes: 10 additions & 0 deletions src/turbomind/kernels/activation_kernels.h
@@ -107,4 +107,14 @@ void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream);
template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream);

template<template<typename T> class Activation, typename T, typename BT>
void invokeFusedBiasResidualActivation(T* out,
const BT* bias,
const T* residual,
const int m,
const int n,
cudaStream_t stream,
const int tp_num,
const int tp_offset);

} // namespace turbomind
1 change: 1 addition & 0 deletions src/turbomind/models/CMakeLists.txt
@@ -13,3 +13,4 @@
# limitations under the License.

add_subdirectory(llama)
add_subdirectory(medusa_plugin)
12 changes: 12 additions & 0 deletions src/turbomind/models/medusa_plugin/CMakeLists.txt
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.8)

find_package(CUDAToolkit REQUIRED)

add_library(Medusa STATIC
medusa_weight.cc
res_block.cc
medusa_head.cc)

set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON)
93 changes: 93 additions & 0 deletions src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -0,0 +1,93 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cublasMMWrapper.h"

namespace turbomind {

template<typename T>
MedusaHead<T>::MedusaHead(size_t in_size,
size_t vocab_size,
size_t medusa_num_heads,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
NcclParam tensor_para,
bool is_free_buffer_after_forward):
in_size_(in_size),
vocab_size_(vocab_size),
medusa_num_heads_(medusa_num_heads),
stream_(stream),
cublas_wrapper_(cublas_wrapper),
allocator_(allocator),
tensor_para_(tensor_para),
is_free_buffer_after_forward_(is_free_buffer_after_forward)
{
resblock_ = std::make_unique<ResBlock<T>>(in_size_, stream_, cublas_wrapper_, tensor_para_);
linear_ = std::make_unique<LlamaLinear<T>>(cublas_wrapper_, stream_);
}

template<typename T>
void MedusaHead<T>::forward(TensorMap* output_tensors,
const TensorMap* input_tensors,
const MedusaWeight<T>& medusa_weight)
{
const size_t batch_size = input_tensors->at("medusa_head_input").shape[0];
const T* hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
T* medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr<T>();
// TODO parallelize this loop
for (int i = 0; i < medusa_num_heads_; i++) {
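// Logits for head i are written at offset i * batch_size * vocab_size, i.e. the output
// layout is [medusa_num_heads, batch_size, vocab_size].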
T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_;
forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
}
}

template<typename T>
void MedusaHead<T>::forward(T* medusa_head_output,
const T* medusa_head_input,
size_t batch_size,
const MedusaWeight<T>& medusa_weight,
int head_id)
{
allocate_buffer(batch_size);
// TODO: support medusa_num_layers > 1; only the first resblock of each head is used for now
resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);

if (tensor_para_.world_size_ > 1) {
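// The vocab projection is tensor-parallel, so sum this rank's partial logits across ranks.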
NcclGuard nccl_guard(tensor_para_, stream_);
ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * vocab_size_, tensor_para_, stream_);
sync_check_cuda_error();
}

free_buffer();
}

template<typename T>
void MedusaHead<T>::allocate_buffer(size_t batch_size)
{
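// Only this rank's tensor-parallel slice of the resblock output is kept, hence the
// division by world_size.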
resblock_buf_ =
(T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
is_allocated_buffer_ = true;
}

template<typename T>
void MedusaHead<T>::free_buffer()
{
if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
allocator_->free((void**)&resblock_buf_);
is_allocated_buffer_ = false;
}
}

template class MedusaHead<float>;
template class MedusaHead<half>;
#ifdef ENABLE_BF16
template class MedusaHead<__nv_bfloat16>;
#endif

} // namespace turbomind
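To make the call pattern concrete, here is a minimal sketch of driving MedusaHead through its TensorMap interface. It is not part of this diff: the driver name and buffer handling are illustrative, and it assumes the FT-style Tensor/TensorMap helpers from src/turbomind/utils/Tensor.h together with the tensor keys used by forward() above.

#include "src/turbomind/models/medusa_plugin/medusa_head.h"
#include "src/turbomind/utils/Tensor.h"

// Hypothetical driver (not part of this PR): hidden_states is a [batch_size, in_size] GPU
// buffer and logits a preallocated [medusa_num_heads, batch_size, vocab_size] GPU buffer.
template<typename T>
void runMedusaHeads(const T*                          hidden_states,
                    T*                                logits,
                    size_t                            batch_size,
                    size_t                            in_size,
                    size_t                            vocab_size,
                    size_t                            medusa_num_heads,
                    const turbomind::MedusaWeight<T>& weight,
                    cudaStream_t                      stream,
                    turbomind::cublasMMWrapper*       cublas_wrapper,
                    turbomind::IAllocator*            allocator,
                    turbomind::NcclParam              tensor_para)
{
    using namespace turbomind;

    MedusaHead<T> head(in_size, vocab_size, medusa_num_heads, stream, cublas_wrapper, allocator, tensor_para);

    TensorMap inputs({{"medusa_head_input",
                       Tensor{MEMORY_GPU, getTensorType<T>(), {batch_size, in_size}, hidden_states}}});
    TensorMap outputs({{"medusa_head_output",
                        Tensor{MEMORY_GPU, getTensorType<T>(), {medusa_num_heads, batch_size, vocab_size}, logits}}});

    head.forward(&outputs, &inputs, weight);
}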
61 changes: 61 additions & 0 deletions src/turbomind/models/medusa_plugin/medusa_head.h
@@ -0,0 +1,61 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Yineng Zhang <[email protected]>
// Zhiwei Bao <[email protected]>

#pragma once

#include "src/turbomind/models/medusa_plugin/medusa_weight.h"
#include "src/turbomind/models/medusa_plugin/res_block.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h>
#include <memory>

namespace turbomind {

template<typename T>
class MedusaHead {
public:
MedusaHead(size_t in_size,
size_t vocab_size,
size_t medusa_num_heads,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
NcclParam tensor_para,
bool is_free_buffer_after_forward = false);
~MedusaHead() = default;
MedusaHead(const MedusaHead&) = delete;
MedusaHead& operator=(const MedusaHead&) = delete;

void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const MedusaWeight<T>& medusa_weight);
void forward(T* medusa_head_output,
const T* medusa_head_input,
size_t batch_size,
const MedusaWeight<T>& medusa_weight,
int head_id);

private:
void allocate_buffer(size_t batch_size);
void free_buffer();

private:
size_t in_size_;
size_t vocab_size_;
size_t medusa_num_heads_;

std::unique_ptr<ResBlock<T>> resblock_;
std::unique_ptr<LlamaLinear<T>> linear_;

T* resblock_buf_ = nullptr;

cudaStream_t stream_;
cublasMMWrapper* cublas_wrapper_;
IAllocator* allocator_;

NcclParam tensor_para_;

bool is_allocated_buffer_ = false;
bool is_free_buffer_after_forward_ = false;
};
} // namespace turbomind