diff --git a/src/turbomind/kernels/attention/array_ops.h b/src/turbomind/kernels/attention/array_ops.h
index 849e42920..8a3efce34 100644
--- a/src/turbomind/kernels/attention/array_ops.h
+++ b/src/turbomind/kernels/attention/array_ops.h
@@ -234,30 +234,18 @@ struct FastRoPE {

     Array inv_freq_;

-    __device__ FastRoPE(int idx, D dims, float base, std::integral_constant)
+    __device__ FastRoPE(int idx, D dims, float base, float ti_scale, std::integral_constant)
     {
-        // constexpr float inv_dims = 1.f / dims;
-        // PRAGMA_UNROLL
-        // for (int i = 0; i < N; i += 2) {
-        //     inv_freq_[i / 2] = fdividef(1.f, powf(base, (idx + i) * inv_dims));
-        // }
-
-        // const float scale_factor = log2f(base) / dims;
-        // PRAGMA_UNROLL
-        // for (int i = 0; i < N; i += 2) {
-        //     inv_freq_[i / 2] = fdividef(1.f, exp2f((idx + i) * scale_factor));
-        // }
-
         // ! Check compiler CSE
         const float scale_factor = -log2f(base) / dims;
         PRAGMA_UNROLL
         for (int i = 0; i < N; i += 2) {
-            inv_freq_[i / 2] = exp2f((idx + i) * scale_factor);
+            inv_freq_[i / 2] = ti_scale * exp2f((idx + i) * scale_factor);
         }
     }

     template
-    __device__ void apply(Array& x, int timestep)
+    __device__ void apply(Array& x, float timestep)
     {
         PRAGMA_UNROLL
         for (int i = 0; i < N; i += 2) {
diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h
index 1d9b78f88..7d83ae885 100644
--- a/src/turbomind/kernels/attention/attention_params.h
+++ b/src/turbomind/kernels/attention/attention_params.h
@@ -50,6 +50,7 @@ struct AttentionParams {
     int   rotary_embedding_dim;
     float rotary_embedding_base;
     int   max_position_embeddings;
+    float rope_ti_scale;  // used for linear RoPE scaling

     // log(n) attention
     bool use_logn_attn;
diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index 015d11877..c724c1e4d 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -159,8 +159,11 @@ struct AttentionUniversal {
             PRAGMA_UNROLL
             for (int c = 0; c < ITER_C; ++c) {
                 const int di = offset.x + c * Map::kDeltaC;
-                FastRoPE rope(
-                    di, std::integral_constant{}, rope_base, std::integral_constant{});
+                FastRoPE rope(di,
+                              std::integral_constant{},
+                              rope_base,
+                              params.rope_ti_scale,
+                              std::integral_constant{});
                 PRAGMA_UNROLL
                 for (int s = 0; s < ITER_S; ++s) {
                     const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
diff --git a/src/turbomind/kernels/attention/kv_cache_utils.cu b/src/turbomind/kernels/attention/kv_cache_utils.cu
index b5e475c3d..3bb3db7c8 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils.cu
+++ b/src/turbomind/kernels/attention/kv_cache_utils.cu
@@ -18,6 +18,7 @@ __global__ void __launch_bounds__(128) ProcessKV(Tkv** blocks,
                                                  const int*   cu_k_len,
                                                  const int*   cu_block_num,
                                                  const float* rope_base,
+                                                 float        rope_ti_scale,
                                                  int          stride_b,
                                                  int          stride_c,
                                                  int          stride_h,
@@ -118,7 +119,11 @@ __global__ void __launch_bounds__(128) ProcessKV(Tkv** blocks,
         PRAGMA_UNROLL
         for (int c = 0; c < ITER_C; ++c) {
             const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di, std::integral_constant{}, base, std::integral_constant{});
+            FastRoPE rope(di,
+                          std::integral_constant{},
+                          base,
+                          rope_ti_scale,
+                          std::integral_constant{});
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
                 const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx;  // sequence local
@@ -177,6 +182,7 @@ void invokeProcessKV(void** blocks,
                      const int*   cu_k_len,
                      const int*   cu_block_num,
                      const float* rope_base,
+                     float        rope_ti_scale,
                      int          stride_b,
                      int          stride_c,
                      int          stride_h,
@@ -211,6 +217,7 @@ void invokeProcessKV(void** blocks,
         cu_k_len,
         cu_block_num,
         rope_base,
+        rope_ti_scale,
         stride_b,
         stride_c,
         stride_h,
@@ -234,6 +241,7 @@ template void invokeProcessKV(void** blocks,
                               const int*   cu_k_len,
                               const int*   cu_block_num,
                               const float* rope_base,
+                              float        rope_ti_scale,
                               int          stride_b,
                               int          stride_c,
                               int          stride_h,
@@ -257,6 +265,7 @@ template void invokeProcessKV(void** blocks,
                               const int*   cu_k_len,
                               const int*   cu_block_num,
                               const float* rope_base,
+                              float        rope_ti_scale,
                               int          stride_b,
                               int          stride_c,
                               int          stride_h,
@@ -279,6 +288,7 @@ __global__ void __launch_bounds__(128) flattenKV(T* k,
                                                  const int*   cu_k_len,
                                                  const int*   cu_block_num,
                                                  const float* rope_base,
+                                                 float        rope_ti_scale,
                                                  int          stride_b,
                                                  int          stride_c,
                                                  int          stride_h,
@@ -358,7 +368,11 @@ __global__ void __launch_bounds__(128) flattenKV(T* k,
         PRAGMA_UNROLL
         for (int c = 0; c < ITER_C; ++c) {
             const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di, std::integral_constant{}, base, std::integral_constant{});
+            FastRoPE rope(di,
+                          std::integral_constant{},
+                          base,
+                          rope_ti_scale,
+                          std::integral_constant{});
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
                 const int ti = offset.y + s * Map::kDeltaS + token_idx;  // sequence local
@@ -390,6 +404,7 @@ void invokeFlattenKV(T* k,
                      const int*   cu_k_len,
                      const int*   cu_block_num,
                      const float* rope_base,
+                     float        rope_ti_scale,
                      int          stride_b,
                      int          stride_c,
                      int          stride_h,
@@ -422,6 +437,7 @@ void invokeFlattenKV(T* k,
         cu_k_len,
         cu_block_num,
         rope_base,
+        rope_ti_scale,
         stride_b,
         stride_c,
         stride_h,
@@ -442,6 +458,7 @@ template void invokeFlattenKV(half* k,
                               const int*   cu_k_len,
                               const int*   cu_block_num,
                               const float* rope_base,
+                              float        rope_ti_scale,
                               int          stride_b,
                               int          stride_c,
                               int          stride_h,
@@ -463,6 +480,7 @@ template void invokeFlattenKV(nv_bfloat16* k,
                               const int*   cu_k_len,
                               const int*   cu_block_num,
                               const float* rope_base,
+                              float        rope_ti_scale,
                               int          stride_b,
                               int          stride_c,
                               int          stride_h,
diff --git a/src/turbomind/kernels/attention/kv_cache_utils.h b/src/turbomind/kernels/attention/kv_cache_utils.h
index 5558b0525..6c17b7b4c 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils.h
+++ b/src/turbomind/kernels/attention/kv_cache_utils.h
@@ -16,6 +16,7 @@ void invokeProcessKV(void** blocks,
                      const int*   cu_k_len,
                      const int*   cu_block_num,
                      const float* rope_base,
+                     float        rope_ti_scale,
                      int          stride_b,
                      int          stride_c,  // cumulative len
                      int          stride_h,
@@ -42,6 +43,7 @@ void invokeProcessKV_(const AttentionParams& params)
                     params.cu_k_len,
                     params.cu_block_cnts,
                     params.rope_theta,
+                    params.rope_ti_scale,
                     0,                                     // stride b
                     params.stride / params.size_per_head,  // stride c
                     1,                                     // stride h
@@ -64,6 +66,7 @@ void invokeFlattenKV(T* k,
                      const int*   cu_k_len,
                      const int*   cu_block_num,
                      const float* rope_base,
+                     float        rope_ti_scale,
                      int          stride_b,
                      int          stride_c,  // cumulative len
                      int          stride_h,
@@ -88,6 +91,7 @@ void invokeFlattenKV_(const AttentionParams& params, int sum_k_len)
                     params.cu_k_len,
                     params.cu_block_cnts,
                     nullptr,  // params.rope_theta,
+                    params.rope_ti_scale,
                     0,
                     1,
                     2 * sum_k_len,
diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index 6c460af6c..08920d5ed 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -102,6 +102,7 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S, D]
                     cu_seq_lens.data().get(),
                     cu_block_cnts.data().get(),
                     nullptr,
+                    1.,
                     2 * head_num * seq_len,
                     0,
                     seq_len,
@@ -127,6 +128,7 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S, D]
                     cu_seq_lens.data().get(),
                     cu_block_cnts.data().get(),
                     nullptr,
+                    1.,
                     2 * head_num * seq_len,
                     0,
                     seq_len,
@@ -206,7 +208,7 @@ int test_attention()
     constexpr size_t kSequenceLen = 0;
     constexpr int    kMaxSplitK   = 1;

-    constexpr int kBlockSz = 128;
+    constexpr int    kBlockSz     = 128;

 #endif

@@ -371,6 +373,7 @@ int test_attention()

     params.rotary_embedding_dim  = kHeadDim;
     params.rotary_embedding_base = kRoPEBase;
+    params.rope_ti_scale         = 1.;

     params.split_cnt = split_cnt.data().get();
     params.partial_L = partial_L.data().get();
@@ -470,6 +473,7 @@ int test_attention()
                     cu_kv_lens.data().get(),
                     cu_block_cnts.data().get(),
                     nullptr,  // DECODING ? nullptr : params.rope_theta,
+                    1.,
                     KvHeadNum * kContextLen,
                     0,
                     kContextLen,
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index a8f5a6a1a..d34251542 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -9,8 +9,8 @@ struct LlamaAttentionParams {
     float rotary_embedding_base;
     int   max_position_embeddings;
     float rope_scaling_factor;
-    // bool use_dynamic_ntk;
-    bool use_logn_attn;
+    bool use_dynamic_ntk;
+    bool use_logn_attn;
 };

 struct EngineParams {
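
Note on the change (not part of the patch): the new `rope_ti_scale` is folded into `inv_freq_` inside the `FastRoPE` constructor, so multiplying every inverse frequency by `ti_scale` is equivalent to scaling the (now `float`) timestep itself. This is the usual linear RoPE / position-interpolation scheme; presumably `rope_ti_scale` is set to `1 / rope_scaling_factor`, but that wiring is outside this diff. A minimal stand-alone host-side sketch of the arithmetic, with illustrative names only:

#include <cmath>
#include <cstdio>

// Rotation angle for channel index `i` of a `dims`-wide head at position
// `timestep`, mirroring FastRoPE's exp2f-based formulation:
//   inv_freq = ti_scale * base^(-i / dims),   angle = timestep * inv_freq
static float rope_angle(int i, int dims, float base, float ti_scale, float timestep)
{
    const float scale_factor = -std::log2(base) / dims;  // same constant as the kernel
    const float inv_freq     = ti_scale * std::exp2(i * scale_factor);
    return timestep * inv_freq;
}

int main()
{
    // With ti_scale = 1/8 (hypothetically rope_scaling_factor = 8), position 4096
    // is rotated the same way as position 512 of the unscaled model.
    const float scaled   = rope_angle(2, 128, 10000.f, 1.f / 8.f, 4096.f);
    const float unscaled = rope_angle(2, 128, 10000.f, 1.f, 512.f);
    std::printf("%f %f\n", scaled, unscaled);
    return 0;
}

Baking the factor into `inv_freq_` once per constructor keeps `apply()` unchanged apart from accepting a float timestep, and the kernels only pay one extra multiply per frequency.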