
Commit

sync
zhyncs committed Mar 13, 2024
1 parent ddc4169 · commit 2a27648
Showing 21 changed files with 759 additions and 167 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -67,6 +67,8 @@ class TurbomindModelConfig:
     rope_scaling_factor: float = 0.0
     use_dynamic_ntk: int = 0
     use_logn_attn: int = 0
+    medusa_num_heads: int = 0
+    medusa_num_layers: int = 0
 
     @classmethod
     def from_dict(cls, env, allow_none=False):
121 changes: 41 additions & 80 deletions src/turbomind/kernels/sampling_topk_kernels.cu
@@ -620,28 +620,21 @@ template void invokeTopKTopPSampling(void* workspace,
                                      cudaStream_t stream);
 
 template<typename T, int BLOCK_SIZE_>
-__global__ void topk_only(const T* __restrict log_probs,
-                          T* tmp_log_probs,
-                          int* topk_tmp_id_buf,
-                          T* topk_tmp_val_buf,
-                          const bool* finished,
-                          const int max_top_k,
-                          const int* top_ks,
-                          const int vocab_size,
-                          const int* end_ids,
-                          const bool* skip_decode)
+__global__ void topk(const T* __restrict log_probs,
+                     T* tmp_log_probs,
+                     int* topk_tmp_id_buf,
+                     T* topk_tmp_val_buf,
+                     const int max_top_k,
+                     const int vocab_size)
 {
     typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
 
     const int tid = threadIdx.x;
     const int bid = blockIdx.x;
 
-    const int batch_id = bid;
-    if (skip_decode != nullptr && skip_decode[batch_id]) {
-        return;
-    }
-    const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
+    const int batch_id = bid;
+    const int k = max_top_k;
     const int tmp_log_buf_index = batch_id * vocab_size;
     const int tmp_topk_buf_index = batch_id * max_top_k;

@@ -676,46 +669,26 @@ __global__ void topk_only(const T* __restrict log_probs,
 #ifdef _MSC_VER
 #define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \
     if (K_MIN <= max_top_k && max_top_k <= K_MAX) { \
-        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs, \
-                                                                          temp_log_probs, \
-                                                                          topk_tmp_id_buf, \
-                                                                          topk_tmp_val_buf, \
-                                                                          finished, \
-                                                                          max_top_k, \
-                                                                          top_ks, \
-                                                                          vocab_size, \
-                                                                          end_ids, \
-                                                                          skip_decode); \
+        topk<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>( \
+            log_probs, temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, max_top_k, vocab_size); \
         break; \
     }
 #else
 #define ONLY_TOPK_CASE_K(K_MIN, K_MAX, BLOCK_SIZE_) \
     case K_MIN ... K_MAX: \
-        topk_only<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>(log_probs, \
-                                                                          temp_log_probs, \
-                                                                          topk_tmp_id_buf, \
-                                                                          topk_tmp_val_buf, \
-                                                                          finished, \
-                                                                          max_top_k, \
-                                                                          top_ks, \
-                                                                          vocab_size, \
-                                                                          end_ids, \
-                                                                          skip_decode); \
+        topk<T, BLOCK_SIZE_><<<batch_size, BLOCK_SIZE_, 0, stream>>>( \
+            log_probs, temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, max_top_k, vocab_size); \
         break;
 #endif
 
 template<typename T>
-void invokeBatchTopKOnly(void* workspace,
-                         size_t& workspace_size,
-                         const T* log_probs,
-                         bool* finished,
-                         const int max_top_k,
-                         const int* top_ks,
-                         const int vocab_size_padded,
-                         const int* end_ids,
-                         cudaStream_t stream,
-                         const int batch_size,
-                         const bool* skip_decode)
+void invokeBatchTopK(void* workspace,
+                     size_t& workspace_size,
+                     const T* log_probs,
+                     const int max_top_k,
+                     const int vocab_size_padded,
+                     cudaStream_t stream,
+                     const int batch_size)
 {
 
     const int vocab_size = vocab_size_padded;
@@ -758,41 +731,29 @@ void invokeBatchTopKOnly(void* workspace,

 #undef ONLY_TOPK_CASE_K
 
-template void invokeBatchTopKOnly(void* workspace,
-                                  size_t& workspace_size,
-                                  const half* log_probs,
-                                  bool* finished,
-                                  const int max_top_k,
-                                  const int* top_ks,
-                                  const int vocab_size_padded,
-                                  const int* end_ids,
-                                  cudaStream_t stream,
-                                  const int batch_size,
-                                  const bool* skip_decode);
-
-template void invokeBatchTopKOnly(void* workspace,
-                                  size_t& workspace_size,
-                                  const float* log_probs,
-                                  bool* finished,
-                                  const int max_top_k,
-                                  const int* top_ks,
-                                  const int vocab_size_padded,
-                                  const int* end_ids,
-                                  cudaStream_t stream,
-                                  const int batch_size,
-                                  const bool* skip_decode);
+template void invokeBatchTopK(void* workspace,
+                              size_t& workspace_size,
+                              const half* log_probs,
+                              const int max_top_k,
+                              const int vocab_size_padded,
+                              cudaStream_t stream,
+                              const int batch_size);
+
+template void invokeBatchTopK(void* workspace,
+                              size_t& workspace_size,
+                              const float* log_probs,
+                              const int max_top_k,
+                              const int vocab_size_padded,
+                              cudaStream_t stream,
+                              const int batch_size);
 
 #ifdef ENABLE_BF16
-template void invokeBatchTopKOnly(void* workspace,
-                                  size_t& workspace_size,
-                                  const __nv_bfloat16* log_probs,
-                                  bool* finished,
-                                  const int max_top_k,
-                                  const int* top_ks,
-                                  const int vocab_size_padded,
-                                  const int* end_ids,
-                                  cudaStream_t stream,
-                                  const int batch_size,
-                                  const bool* skip_decode);
+template void invokeBatchTopK(void* workspace,
+                              size_t& workspace_size,
+                              const __nv_bfloat16* log_probs,
+                              const int max_top_k,
+                              const int vocab_size_padded,
+                              cudaStream_t stream,
+                              const int batch_size);
 #endif
 } // namespace turbomind
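
The refactor drops the finished / top_ks / end_ids / skip_decode arguments, so the renamed kernel reduces every request in the batch with the same max_top_k. Below is a minimal, hypothetical usage sketch of the new entry point; it assumes the two-phase workspace-query convention used by the other sampling launchers in this file (a first call with a null workspace only fills in workspace_size), and the helper run_batch_topk plus the comment about where results land are illustrative, not part of this commit.

// Hypothetical caller of the simplified batched top-k launcher (not part of the commit).
#include <cuda_runtime.h>
#include "src/turbomind/kernels/sampling_topk_kernels.h"

void run_batch_topk(const float* d_log_probs,  // [batch_size, vocab_size_padded], device memory
                    int          batch_size,
                    int          vocab_size_padded,
                    int          max_top_k,
                    cudaStream_t stream)
{
    size_t workspace_size = 0;

    // 1) Query the required scratch size (assumed convention: null workspace -> size only).
    turbomind::invokeBatchTopK<float>(
        nullptr, workspace_size, d_log_probs, max_top_k, vocab_size_padded, stream, batch_size);

    // 2) Allocate the workspace and run the batched top-k; every request now uses
    //    the same max_top_k, since the per-request top_ks array was removed.
    void* d_workspace = nullptr;
    cudaMalloc(&d_workspace, workspace_size);
    turbomind::invokeBatchTopK<float>(
        d_workspace, workspace_size, d_log_probs, max_top_k, vocab_size_padded, stream, batch_size);

    // The candidate ids/values are presumably left in the topk_tmp_id_buf / topk_tmp_val_buf
    // regions carved out of the workspace for the caller to consume.
    cudaStreamSynchronize(stream);
    cudaFree(d_workspace);
}
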
18 changes: 7 additions & 11 deletions src/turbomind/kernels/sampling_topk_kernels.h
@@ -96,15 +96,11 @@ void invokeTopKTopPSampling(void* workspace,
                             cudaStream_t stream);
 
 template<typename T>
-void invokeBatchTopKOnly(void* workspace,
-                         size_t& workspace_size,
-                         const T* log_probs,
-                         bool* finished,
-                         const int max_top_k,
-                         const int* top_ks,
-                         const int vocab_size_padded,
-                         const int* end_ids,
-                         cudaStream_t stream,
-                         const int batch_size,
-                         const bool* skip_decode);
+void invokeBatchTopK(void* workspace,
+                     size_t& workspace_size,
+                     const T* log_probs,
+                     const int max_top_k,
+                     const int vocab_size_padded,
+                     cudaStream_t stream,
+                     const int batch_size);
 } // namespace turbomind
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/CMakeLists.txt
@@ -37,7 +37,8 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
                       nccl_utils
                       cuda_utils
                       logger
-                      llama_fmha)
+                      llama_fmha
+                      Medusa)
 
 
 add_executable(llama_gemm llama_gemm.cc)
