add linear rope scaling
lzhangzz committed Mar 4, 2024
1 parent 586b7c7 commit 7c6773e
Showing 7 changed files with 40 additions and 22 deletions.
18 changes: 3 additions & 15 deletions src/turbomind/kernels/attention/array_ops.h
@@ -234,30 +234,18 @@ struct FastRoPE {

Array<float, N / 2> inv_freq_;

__device__ FastRoPE(int idx, D dims, float base, std::integral_constant<int, N>)
__device__ FastRoPE(int idx, D dims, float base, float ti_scale, std::integral_constant<int, N>)
{
// constexpr float inv_dims = 1.f / dims;
// PRAGMA_UNROLL
// for (int i = 0; i < N; i += 2) {
// inv_freq_[i / 2] = fdividef(1.f, powf(base, (idx + i) * inv_dims));
// }

// const float scale_factor = log2f(base) / dims;
// PRAGMA_UNROLL
// for (int i = 0; i < N; i += 2) {
// inv_freq_[i / 2] = fdividef(1.f, exp2f((idx + i) * scale_factor));
// }

// ! Check compiler CSE
const float scale_factor = -log2f(base) / dims;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
inv_freq_[i / 2] = exp2f((idx + i) * scale_factor);
inv_freq_[i / 2] = ti_scale * exp2f((idx + i) * scale_factor);
}
}

template<typename T>
__device__ void apply(Array<T, N>& x, int timestep)
__device__ void apply(Array<T, N>& x, float timestep)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
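A note on the change above (not part of the commit): linear RoPE scaling, often called position interpolation, stretches the position index by a constant factor before the rotary angles are computed. Folding ti_scale into inv_freq_ is equivalent to evaluating the original RoPE at the scaled position timestep * ti_scale. A minimal host-side sketch of that equivalence, assuming ti_scale = 1 / scaling_factor; the function name and layout are illustrative only:

    // Reference sketch only; mirrors the kernel's exp2f/log2f formulation.
    #include <cmath>
    #include <vector>

    std::vector<float> rope_angles(int head_dim, float base, float ti_scale, float timestep)
    {
        std::vector<float> angles(head_dim / 2);
        for (int i = 0; i < head_dim; i += 2) {
            // inv_freq = ti_scale * base^(-i / head_dim), written as exp2 of a log2 term
            const float inv_freq = ti_scale * std::exp2(-float(i) * std::log2(base) / float(head_dim));
            // rotating by timestep * inv_freq equals the unscaled RoPE at position timestep * ti_scale
            angles[i / 2] = timestep * inv_freq;
        }
        return angles;
    }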
1 change: 1 addition & 0 deletions src/turbomind/kernels/attention/attention_params.h
@@ -50,6 +50,7 @@ struct AttentionParams {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
float rope_ti_scale; // used for linear RoPE scaling

// log(n) attention
bool use_logn_attn;
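The kernels above read this new field; the code that populates it is not part of this diff. A plausible wiring, assuming the conventional linear-scaling relation ti_scale = 1 / rope_scaling_factor (hypothetical, not confirmed by the commit):

    // Hypothetical setup (assumption): derive the timestep scale from the
    // model's linear RoPE scaling factor before launching the kernels.
    params.rope_ti_scale = 1.f;
    if (rope_scaling_factor > 0.f) {
        params.rope_ti_scale = 1.f / rope_scaling_factor;
    }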
7 changes: 5 additions & 2 deletions src/turbomind/kernels/attention/attention_universal.h
@@ -159,8 +159,11 @@ struct AttentionUniversal {
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
FastRoPE rope(
di, std::integral_constant<int, kHeadDim>{}, rope_base, std::integral_constant<int, kVecSize>{});
FastRoPE rope(di,
std::integral_constant<int, kHeadDim>{},
rope_base,
params.rope_ti_scale,
std::integral_constant<int, kVecSize>{});
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
22 changes: 20 additions & 2 deletions src/turbomind/kernels/attention/kv_cache_utils.cu
@@ -18,6 +18,7 @@ __global__ void __launch_bounds__(128) ProcessKV(Tkv** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -118,7 +119,11 @@ __global__ void __launch_bounds__(128) ProcessKV(Tkv** blocks,
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
FastRoPE rope(di, std::integral_constant<int, HeadDim>{}, base, std::integral_constant<int, kVecSize>{});
FastRoPE rope(di,
std::integral_constant<int, HeadDim>{},
base,
rope_ti_scale,
std::integral_constant<int, kVecSize>{});
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx; // sequence local
@@ -177,6 +182,7 @@ void invokeProcessKV(void** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -211,6 +217,7 @@
cu_k_len,
cu_block_num,
rope_base,
rope_ti_scale,
stride_b,
stride_c,
stride_h,
@@ -234,6 +241,7 @@ template void invokeProcessKV(void** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -257,6 +265,7 @@ template void invokeProcessKV(void** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -279,6 +288,7 @@ __global__ void __launch_bounds__(128) flattenKV(T* k,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -358,7 +368,11 @@ __global__ void __launch_bounds__(128) flattenKV(T* k,
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
FastRoPE rope(di, std::integral_constant<int, HeadDim>{}, base, std::integral_constant<int, kVecSize>{});
FastRoPE rope(di,
std::integral_constant<int, HeadDim>{},
base,
rope_ti_scale,
std::integral_constant<int, kVecSize>{});
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
const int ti = offset.y + s * Map::kDeltaS + token_idx; // sequence local
@@ -390,6 +404,7 @@ void invokeFlattenKV(T* k,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -422,6 +437,7 @@
cu_k_len,
cu_block_num,
rope_base,
rope_ti_scale,
stride_b,
stride_c,
stride_h,
@@ -442,6 +458,7 @@ template void invokeFlattenKV(half* k,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
@@ -463,6 +480,7 @@ template void invokeFlattenKV(nv_bfloat16* k,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c,
int stride_h,
4 changes: 4 additions & 0 deletions src/turbomind/kernels/attention/kv_cache_utils.h
@@ -16,6 +16,7 @@ void invokeProcessKV(void** blocks,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c, // cumulative len
int stride_h,
@@ -42,6 +43,7 @@ void invokeProcessKV_(const AttentionParams<T>& params)
params.cu_k_len,
params.cu_block_cnts,
params.rope_theta,
params.rope_ti_scale,
0, // stride b
params.stride / params.size_per_head, // stride c
1, // stride h
@@ -64,6 +66,7 @@ void invokeFlattenKV(T* k,
const int* cu_k_len,
const int* cu_block_num,
const float* rope_base,
float rope_ti_scale,
int stride_b,
int stride_c, // cumulative len
int stride_h,
@@ -88,6 +91,7 @@ void invokeFlattenKV_(const AttentionParams<T>& params, int sum_k_len)
params.cu_k_len,
params.cu_block_cnts,
nullptr, // params.rope_theta,
params.rope_ti_scale,
0,
1,
2 * sum_k_len,
6 changes: 5 additions & 1 deletion src/turbomind/kernels/attention/test_attention.cu
@@ -102,6 +102,7 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S, D]
cu_seq_lens.data().get(),
cu_block_cnts.data().get(),
nullptr,
1.,
2 * head_num * seq_len,
0,
seq_len,
@@ -127,6 +128,7 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S, D]
cu_seq_lens.data().get(),
cu_block_cnts.data().get(),
nullptr,
1.,
2 * head_num * seq_len,
0,
seq_len,
@@ -206,7 +208,7 @@ int test_attention()
constexpr size_t kSequenceLen = 0;
constexpr int kMaxSplitK = 1;

constexpr int kBlockSz = 128;
constexpr int kBlockSz = 128;

#endif

@@ -371,6 +373,7 @@ int test_attention()

params.rotary_embedding_dim = kHeadDim;
params.rotary_embedding_base = kRoPEBase;
params.rope_ti_scale = 1.;

params.split_cnt = split_cnt.data().get();
params.partial_L = partial_L.data().get();
@@ -470,6 +473,7 @@ int test_attention()
cu_kv_lens.data().get(),
cu_block_cnts.data().get(),
nullptr, // DECODING ? nullptr : params.rope_theta,
1.,
KvHeadNum * kContextLen,
0,
kContextLen,
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/llama_params.h
@@ -9,8 +9,8 @@ struct LlamaAttentionParams {
float rotary_embedding_base;
int max_position_embeddings;
float rope_scaling_factor;
// bool use_dynamic_ntk;
bool use_logn_attn;
bool use_dynamic_ntk;
bool use_logn_attn;
};

struct EngineParams {
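Under the hypothetical wiring sketched earlier, rope_scaling_factor = 2.0 gives rope_ti_scale = 0.5, so a token at position 6000 is rotated as if it sat at position 3000; that is the usual position-interpolation trade-off, giving up angular resolution between neighboring positions for roughly scaling_factor times the usable context relative to max_position_embeddings.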
