feat: add invokeBatchTopK and invokeMedusaBatchMatch
b4b4o authored and zhyncs committed Mar 13, 2024
1 parent fa4fafe commit ddc4169
Showing 7 changed files with 333 additions and 6 deletions.
20 changes: 20 additions & 0 deletions src/turbomind/kernels/reduce_kernel_utils.cuh
@@ -336,6 +336,26 @@ struct TopK_2 {
}
};

template<>
struct TopK_2<int> {
int p = -1;
int u = std::numeric_limits<int>::min();

__device__ __forceinline__ void insert(int elem, int elem_id)
{
if (elem > u) {
u = elem;
p = elem_id;
}
}

__device__ __forceinline__ void init()
{
p = -1;
u = std::numeric_limits<int>::min();
}
};

template<typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
{
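For reference, the new specialization plugs into the same block-reduce pattern the float/half versions use. A minimal sketch (not part of this commit) of a block-wide integer argmax built from TopK_2<int>, reduce_topk_op_2, and cub::BlockReduce — the kernel name is hypothetical, and the surrounding headers are assumed to be included:

template<int BLOCK_SIZE>
__global__ void block_argmax(const int* values, int n, int* out_idx, int* out_val)
{
    typedef cub::BlockReduce<TopK_2<int>, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopK_2<int> partial;
    partial.init();
    for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) {
        partial.insert(values[i], i);  // track the running max and its index
    }
    TopK_2<int> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<int>);
    if (threadIdx.x == 0) {  // thread 0 holds the block-wide result
        *out_idx = total.p;
        *out_val = total.u;
    }
}

This is exactly the shape of the Medusa match kernel added in llama_kernels.cu below.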
176 changes: 176 additions & 0 deletions src/turbomind/kernels/sampling_topk_kernels.cu
@@ -619,4 +619,180 @@ template void invokeTopKTopPSampling(void* workspace,
const int* end_ids,
cudaStream_t stream);

// Pure top-k selection (no sampling): one block per batch entry; the k largest
// log-probs are extracted one at a time via a block-wide argmax reduction.
// Note: `finished` and `end_ids` are not referenced in this kernel.
template<typename T, int BLOCK_SIZE_>
__global__ void topk_only(const T* __restrict log_probs,
T* tmp_log_probs,
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size,
const int* end_ids,
const bool* skip_decode)
{
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int tid = threadIdx.x;
const int bid = blockIdx.x;

const int batch_id = bid;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
const int tmp_log_buf_index = batch_id * vocab_size;
const int tmp_topk_buf_index = batch_id * max_top_k;

TopK_2<T> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
int index = elem_id + tmp_log_buf_index;
tmp_log_probs[index] = log_probs[index];
}

for (int ite = 0; ite < k; ite++) {
partial.init();
#pragma unroll
for (int elem_id = tid; elem_id < vocab_size; elem_id += BLOCK_SIZE_) {
int index = elem_id + tmp_log_buf_index;
partial.insert(tmp_log_probs[index], index);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);

if (tid == 0) {
const int index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p % vocab_size;
topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL;
}
__syncthreads();
}
}

// `case K_MIN ... K_MAX:` ranges are a GNU extension that MSVC does not support,
// so the MSVC variant emulates the switch with an if/break chain inside do { } while (0).
#ifdef _MSC_VER
#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \
if (K_MIN <= max_top_k && max_top_k <= K_MAX) { \
topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs, \
temp_log_probs, \
topk_tmp_id_buf, \
topk_tmp_val_buf, \
finished, \
max_top_k, \
top_ks, \
vocab_size, \
end_ids, \
skip_decode); \
break; \
}
#else
#define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \
case K_MIN ... K_MAX: \
topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs, \
temp_log_probs, \
topk_tmp_id_buf, \
topk_tmp_val_buf, \
finished, \
max_top_k, \
top_ks, \
vocab_size, \
end_ids, \
skip_decode); \
break;
#endif

template<typename T>
void invokeBatchTopKOnly(void* workspace,
size_t& workspace_size,
const T* log_probs,
bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode)
{
const int vocab_size = vocab_size_padded;
int temp_log_probs_buf_size = batch_size * vocab_size;
int topk_tmp_ids_buf_size = batch_size * max_top_k;
int topk_tmp_val_buf_size = batch_size * max_top_k;

// Round each size up to a multiple of 4 elements so the buffers carved out of the workspace stay aligned.
temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;

if (workspace == nullptr) {
workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
+ sizeof(T) * topk_tmp_val_buf_size;
return;
}

T* temp_log_probs = (T*)workspace;
int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size);
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
#ifdef _MSC_VER
do {
ONLY_TOPK_CASE_K(1, 16, 128 * 8);
ONLY_TOPK_CASE_K(17, 32, 256 * 8);
ONLY_TOPK_CASE_K(33, 64, 256 * 8);
ONLY_TOPK_CASE_K(65, 1024, 256 * 8);
throw std::domain_error(fmtstr("topk_only kernel supports 1 <= k <= 1024, but got k=%d", max_top_k));
} while (0);
#else
switch (max_top_k) {
ONLY_TOPK_CASE_K(1, 16, 128 * 8);
ONLY_TOPK_CASE_K(17, 32, 256 * 8);
ONLY_TOPK_CASE_K(33, 64, 256 * 8);
ONLY_TOPK_CASE_K(65, 1024, 256 * 8);
default:
throw std::domain_error(fmtstr("topk_only kernel supports 1 <= k <= 1024, but got k=%d", max_top_k));
}
#endif
}

#undef ONLY_TOPK_CASE_K

template void invokeBatchTopKOnly(void* workspace,
size_t& workspace_size,
const half* log_probs,
bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);

template void invokeBatchTopKOnly(void* workspace,
size_t& workspace_size,
const float* log_probs,
bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);

#ifdef ENABLE_BF16
template void invokeBatchTopKOnly(void* workspace,
size_t& workspace_size,
const __nv_bfloat16* log_probs,
bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
#endif
} // namespace turbomind
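As the nullptr branch above shows, invokeBatchTopKOnly follows the usual two-call workspace protocol: a first call with workspace == nullptr only reports the required size and launches nothing. A hedged host-side usage sketch (device pointers are hypothetical; finished/end_ids are unused by the kernel, and top_ks/skip_decode may be nullptr, in which case every entry uses max_top_k and nothing is skipped):

size_t ws_size = 0;
// 1) Size query: no kernel launch, only ws_size is written.
turbomind::invokeBatchTopKOnly<float>(nullptr, ws_size, d_log_probs, nullptr,
                                      max_top_k, nullptr, vocab_size_padded,
                                      nullptr, stream, batch_size, nullptr);
void* d_ws = nullptr;
cudaMalloc(&d_ws, ws_size);
// 2) Actual run; the selected ids/values are left in the workspace buffers
//    (topk_tmp_id_buf / topk_tmp_val_buf) carved out at the top of the function.
turbomind::invokeBatchTopKOnly<float>(d_ws, ws_size, d_log_probs, nullptr,
                                      max_top_k, nullptr, vocab_size_padded,
                                      nullptr, stream, batch_size, nullptr);
cudaFree(d_ws);

Note the function exposes no separate output parameters, so a consumer presumably reads the ids/values out of the workspace at the same offsets the function computes internally.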
12 changes: 12 additions & 0 deletions src/turbomind/kernels/sampling_topk_kernels.h
@@ -95,4 +95,16 @@ void invokeTopKTopPSampling(void* workspace,
const int* end_ids,
cudaStream_t stream);

template<typename T>
void invokeBatchTopKOnly(void* workspace,
size_t& workspace_size,
const T* log_probs,
bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
} // namespace turbomind
66 changes: 66 additions & 0 deletions src/turbomind/models/llama/llama_kernels.cu
@@ -905,4 +905,70 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t stream)
});
}

template<int BLOCK_SIZE_>
__global__ void medusaBatchedMatchKernel(const int* __restrict__ input_ids,
const int* __restrict__ output_ids,
int* match_idx,
int* match_length,
const int path_num,
int size)
{
// input_ids / output_ids layout: [batch_size, path_num, 1 + medusa_head_num]; here size == medusa_head_num
const int length = size + 1;
const int limit_r = gridDim.x * path_num * length;
const int bid = blockIdx.x; // (0, batch_size)
const int tid = threadIdx.x; // (0, BLOCK_SIZE_)

typedef cub::BlockReduce<TopK_2<int>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

TopK_2<int> partial;
partial.init();

for (int idx = tid; idx < path_num; idx += BLOCK_SIZE_) {
int start_id = bid * path_num * length + idx * length; // belong to (bid, path_id)
int accumulate_length = 0;
for (int i = 0; i < size && (start_id + i) < limit_r; ++i) {
if (input_ids[start_id + i + 1] == output_ids[start_id + i]) {
++accumulate_length;
}
else {
break;
}
}
partial.insert(accumulate_length, idx);
}

TopK_2<int> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<int>);

if (tid == 0) {
const int index = bid;
match_idx[index] = total.p;
match_length[index] = total.u;
}
__syncthreads();
}

void invokeMedusaBatchedMatchKernel(const int* input_ids,
const int* output_ids,
int* max_match_idx,
int* max_match_length,
int batch_size,
int path_num,
int medusa_head_num,
cudaStream_t stream)
{
// inputs:
// input_ids: [batch_size, path_num, 1 + medusa_head_num]
// output_ids: [batch_size, path_num, 1 + medusa_head_num]
// outputs:
// max_match_idx: [batch_size]
// max_match_length: [batch_size]
dim3 grid, block;
grid.x = batch_size;
block.x = 64;
medusaBatchedMatchKernel<64><<<grid, block, 0, stream>>>(
input_ids, output_ids, max_match_idx, max_match_length, path_num, medusa_head_num);
}

} // namespace turbomind
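To make the matching rule concrete: for each batch entry, the kernel finds the candidate path whose draft tokens (input_ids, shifted left by one) agree with the verified tokens (output_ids) for the longest prefix, and reports that path's index and match length. A CPU reference of the same computation — a hedged sketch, not part of the commit; tie-breaking between equal-length paths may differ from the GPU reduction order:

#include <limits>

void medusaBatchedMatchReference(const int* input_ids, const int* output_ids,
                                 int* max_match_idx, int* max_match_length,
                                 int batch_size, int path_num, int medusa_head_num)
{
    const int length = 1 + medusa_head_num;
    for (int b = 0; b < batch_size; ++b) {
        int best_path = -1;
        int best_len = std::numeric_limits<int>::min();  // mirrors TopK_2<int>::init()
        for (int p = 0; p < path_num; ++p) {
            const int base = (b * path_num + p) * length;
            int len = 0;
            // draft token i+1 must equal the verified output at position i
            while (len < medusa_head_num && input_ids[base + len + 1] == output_ids[base + len]) {
                ++len;
            }
            if (len > best_len) {
                best_len = len;
                best_path = p;
            }
        }
        max_match_idx[b] = best_path;
        max_match_length[b] = best_len;
    }
}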
9 changes: 9 additions & 0 deletions src/turbomind/models/llama/llama_kernels.h
@@ -167,4 +167,13 @@ inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t stream)
TM_LOG_ERROR("--------> rank = %d, step = %d, seq_len = %d <--------", tp_rank, step, h_seq_len);
}

void invokeMedusaBatchedMatchKernel(const int* input_ids,
const int* output_ids,
int* max_match_idx,
int* max_match_length,
int batch_size,
int path_num,
int medusa_head_num,
cudaStream_t stream);

} // namespace turbomind
(2 of the 7 changed files are not shown.)
