Paged attn alibi support
EricLBuehler committed Nov 22, 2024
1 parent 2b20951 commit ee57bc4
Showing 5 changed files with 44 additions and 24 deletions.
@@ -23,11 +23,12 @@ impl PagedAttention
num_key_value_heads: Option<usize>,
sliding_window: Option<usize>,
device: &Device,
alibi_slopes: Option<Vec<f64>>,
alibi_slopes: Option<Vec<f32>>,
) -> Result<Self> {
let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads);
let num_queries_per_kv = num_attention_heads / num_key_value_heads;
let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes {
assert_eq!(alibi_slopes.len(), num_attention_heads); // one slope per query head
Some(Tensor::new(alibi_slopes, device)?)
} else {
None
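The hunk above (the first of two constructor sites in this commit) narrows the slope type from f64 to f32 — matching the CUDA kernels, which read the slopes as float — and validates the slope count. The commit takes the slopes as input rather than computing them; for reference, here is a minimal sketch of the canonical per-head slope schedule from the ALiBi paper (Press et al.). This helper is not part of the commit, so treat its name and placement as assumptions:

```rust
/// A sketch of the standard ALiBi slope schedule: one positive slope per
/// attention head, geometrically decreasing powers of two.
fn alibi_slopes(num_heads: usize) -> Vec<f32> {
    fn power_of_2_slopes(n: usize) -> Vec<f32> {
        // start = 2^(-2^(-(log2(n) - 3))); e.g. 0.5 when n = 8.
        let start = 2f32.powf(-(2f32.powf(-((n as f32).log2() - 3.0))));
        (0..n).map(|i| start.powi(i as i32 + 1)).collect()
    }
    if num_heads.is_power_of_two() {
        power_of_2_slopes(num_heads)
    } else {
        // Non-power-of-two head counts: slopes for the largest power of two
        // below, then every other slope from the next size up.
        let closest = num_heads.next_power_of_two() / 2;
        let mut slopes = power_of_2_slopes(closest);
        slopes.extend(
            power_of_2_slopes(2 * closest)
                .into_iter()
                .step_by(2)
                .take(num_heads - closest),
        );
        slopes
    }
}
```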
24 changes: 9 additions & 15 deletions mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -7,16 +7,10 @@ use crate::{
pipeline::text_models_inputs_processor::PagedAttentionInputMetadata,
};

const _PARTITION_SIZE: usize = 512;

#[allow(dead_code)]
pub struct PagedAttention {
num_attention_heads: usize,
head_dim: usize,
num_key_value_heads: usize,
scale: f32,
sliding_window: Option<usize>,
num_queries_per_kv: usize,
n_kv_groups: usize,
alibi_slopes: Option<Tensor>,
}

@@ -28,22 +22,20 @@ impl PagedAttention
num_key_value_heads: Option<usize>,
sliding_window: Option<usize>,
device: &Device,
alibi_slopes: Option<Vec<f64>>,
alibi_slopes: Option<Vec<f32>>,
) -> Result<Self> {
let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads);
let num_queries_per_kv = num_attention_heads / num_key_value_heads;
let n_kv_groups = num_attention_heads / num_key_value_heads;
let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes {
assert_eq!(alibi_slopes.len(), num_attention_heads); // one slope per query head
Some(Tensor::new(alibi_slopes, device)?)
} else {
None
};
Ok(Self {
num_attention_heads,
head_dim,
num_key_value_heads,
scale,
sliding_window,
num_queries_per_kv,
n_kv_groups,
alibi_slopes,
})
}
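A hypothetical call site wiring slopes into the layer. The parameters above this hunk (num_attention_heads, head_dim, scale) are inferred from the struct fields, so treat the exact argument order as an assumption:

```rust
// `head_dim`, `device`, and the head counts are assumed to be in scope.
let slopes = alibi_slopes(num_attention_heads); // see the sketch above
let attn = PagedAttention::new(
    num_attention_heads,
    head_dim,
    1.0 / (head_dim as f32).sqrt(), // the conventional softmax scale
    Some(num_key_value_heads),
    None, // no sliding window
    &device,
    Some(slopes), // one f32 slope per query head
)?;
```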
@@ -81,6 +73,7 @@ impl PagedAttention
let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
let (_, key_value_heads, _, _) = key.shape().dims4()?;

#[allow(clippy::cast_possible_truncation)]
let att = match attention_mask {
None => None,
Some(mask) => Some(Sdpa.run_attention(
@@ -90,11 +83,11 @@
Some(mask),
None,
&SdpaParams {
n_kv_groups: attention_heads / key_value_heads,
n_kv_groups: self.n_kv_groups,
use_flash_attn: false,
softcap: softcapping.map(|x| x as f32),
softmax_scale: self.scale,
sliding_window: None,
sliding_window: self.sliding_window,
},
)?),
};
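Two behavioral notes on this hunk: n_kv_groups = num_attention_heads / num_key_value_heads is now computed once in the constructor instead of per call, and the layer's sliding_window is threaded through instead of being hardcoded to None. The stored value is the grouped-query-attention fan-out: each KV head serves n_kv_groups query heads. A sketch of what a consumer of SdpaParams::n_kv_groups typically does (illustrative; not this repo's Sdpa internals):

```rust
use candle_core::{Result, Tensor};

/// Broadcast each KV head across its group of query heads.
/// `kv` has shape (batch, num_kv_heads, seq_len, head_dim).
fn repeat_kv(kv: &Tensor, n_kv_groups: usize) -> Result<Tensor> {
    if n_kv_groups == 1 {
        return Ok(kv.clone());
    }
    let (b, kv_heads, seq, hd) = kv.dims4()?;
    kv.unsqueeze(2)? // (b, kv_heads, 1, seq, hd)
        .expand((b, kv_heads, n_kv_groups, seq, hd))?
        .reshape((b, kv_heads * n_kv_groups, seq, hd))
}
```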
@@ -159,6 +152,7 @@ impl PagedAttention
value_cache.as_ref().unwrap(),
input_metadata.block_tables.as_ref().unwrap(),
input_metadata.context_lens.as_ref().unwrap(),
self.alibi_slopes.as_ref(),
input_metadata.max_context_len.unwrap(),
self.scale,
softcapping.unwrap_or(1.0f64) as f32,
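In the decode path above, softcapping.unwrap_or(1.0f64) feeds the kernel, and by this API's convention (see the doc comment in the backend below) a cap of 1.0 means "do nothing". Written out, the Gemma 2-style operation is (a sketch of the math, not kernel code):

```rust
/// Squash an attention logit into (-cap, cap) before softmax.
/// A cap of 1.0 is this API's "softcapping disabled" sentinel.
fn softcap(logit: f32, cap: f32) -> f32 {
    if cap == 1.0 {
        logit
    } else {
        cap * (logit / cap).tanh()
    }
}
```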
19 changes: 19 additions & 0 deletions mistralrs-paged-attn/src/backend/paged_attention.rs
@@ -16,6 +16,7 @@ struct PagedAttention
value_cache: Tensor,
block_tables: Tensor,
context_lens: Tensor,
alibi_slopes: Option<Tensor>,
max_context_len: usize,
}

@@ -101,6 +102,19 @@ impl PagedAttention
let cl = cl.slice(cl_l.start_offset()..);
let bt = bt.slice(bt_l.start_offset()..);

// ALiBi is optional: pass the slopes' CUDA device pointer when present,
// or a null pointer, which the kernels treat as "no bias".
let alibi_s_ptr = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
let (alibi_s, alibi_s_l) = alibi_slopes.storage_and_layout();
let alibi_s = match &*alibi_s {
Storage::Cuda(alibi_s) => alibi_s,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_s = alibi_s.as_cuda_slice::<f32>()?;
let alibi_s = alibi_s.slice(alibi_s_l.start_offset()..);
*alibi_s.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};

let (num_seqs, num_heads, head_size) = q_l.shape().dims3()?;
if !(head_size == 64
|| head_size == 80
@@ -173,6 +187,7 @@
q_ptr,
kc_ptr,
vc_ptr,
alibi_s_ptr,
num_kv_heads as c_int,
self.softmax_scale,
self.softcapping,
@@ -210,6 +225,7 @@
q_ptr,
kc_ptr,
vc_ptr,
alibi_s_ptr,
num_kv_heads as c_int,
self.softmax_scale,
self.softcapping,
@@ -270,6 +286,7 @@ impl candle::CustomOp1 for PagedAttention
/// * `max_context_len` - Max of `context_len`
/// * `softmax_scale` - scaling factor
/// * `softcapping` - Softcapping value as in Gemma 2. Using 1.0 means do nothing.
/// * `alibi_slopes` - Optional alibi slopes, `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(num_sequences, num_heads_q, head_size)`.
#[allow(clippy::too_many_arguments)]
@@ -279,6 +296,7 @@ pub fn paged_attention(
value_cache: &Tensor,
block_tables: &Tensor,
context_lens: &Tensor,
alibi_slopes: Option<&Tensor>,
max_context_len: usize,
softmax_scale: f32,
softcapping: f32,
@@ -291,6 +309,7 @@
context_lens: context_lens.clone(),
max_context_len,
softcapping,
alibi_slopes: alibi_slopes.cloned(),
};
q.apply_op1(op)
}
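A hypothetical decode-step invocation of the updated function. Tensor shapes follow the doc comment; any parameters elided by the truncated hunk above are omitted here too, so treat this as a sketch rather than the exact argument list:

```rust
let out = paged_attention(
    &query,                // (num_sequences, num_heads_q, head_size)
    &key_cache,
    &value_cache,
    &block_tables,
    &context_lens,
    alibi_slopes.as_ref(), // Option<&Tensor>: f32, (num_heads_q), on the same CUDA device
    max_context_len,
    1.0 / (head_size as f32).sqrt(),
    1.0,                   // softcapping disabled
)?;
```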
2 changes: 2 additions & 0 deletions mistralrs-paged-attn/src/ffi.rs
@@ -24,6 +24,7 @@ extern "C" {
query: *const c_void,
key_cache: *const c_void,
value_cache: *const c_void,
alibi_slopes: *const c_void,
num_kv_heads: c_int,
scale: f32,
softcapping: f32,
@@ -51,6 +52,7 @@ extern "C" {
query: *const c_void,
key_cache: *const c_void,
value_cache: *const c_void,
alibi_slopes: *const c_void,
num_kv_heads: c_int,
scale: f32,
softcapping: f32,
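Both declarations gain the alibi_slopes pointer in the same position as the new parameter in pagedattention.cu. Nothing verifies this agreement: the linker matches symbol names only, so the Rust declarations and the C definitions must be kept in positional lockstep by hand. A toy illustration of the contract, using a hypothetical symbol that is not part of this crate:

```rust
use core::ffi::{c_int, c_void};

extern "C" {
    // Hypothetical kernel entry point, for illustration only. The argument
    // list must match the C definition position-for-position.
    fn demo_kernel(data: *const c_void, alibi_slopes: *const c_void, n: c_int);
}

fn call_without_alibi(data: *const c_void, n: c_int) {
    // The nullable-pointer convention used above: a null alibi_slopes
    // pointer means "ALiBi disabled".
    unsafe { demo_kernel(data, std::ptr::null(), n) }
}
```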
20 changes: 12 additions & 8 deletions mistralrs-paged-attn/src/pagedattention.cu
@@ -585,7 +585,7 @@ __global__ void paged_attention_v2_reduce_kernel(
block_tables, \
context_lens, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
reinterpret_cast<float*>(alibi_slopes), \
q_stride, \
kv_block_stride, \
kv_head_stride);
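The macro change above threads the raw slopes pointer into the vLLM-style kernels, which add a per-head additive bias to each query/key score before softmax. The bias math, restated in Rust for clarity (a sketch, not the CUDA code):

```rust
/// ALiBi bias for one (head, key-position) pair. The newest key
/// (token_idx == context_len - 1) gets zero bias; older keys get a penalty
/// growing linearly with distance, scaled by the head's slope.
fn alibi_bias(slope: f32, token_idx: usize, context_len: usize) -> f32 {
    slope * (token_idx as f32 - (context_len as f32 - 1.0))
}
```

Because the slopes are positive and token_idx <= context_len - 1, the bias is never positive: distant keys are progressively down-weighted, which is what lets ALiBi models attend over contexts longer than they were trained on.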
@@ -600,6 +600,7 @@ void paged_attention_v1_launcher(
void *query,
void *key_cache,
void *value_cache,
void* __restrict__ alibi_slopes,
int num_kv_heads,
float scale,
float softcapping,
@@ -619,8 +620,7 @@
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
// assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = nullptr;
// NOTE: alibi_slopes is optional. It may be nullptr.

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
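DIVIDE_ROUND_UP here is plain ceiling division; rounding max_context_len up to a whole number of blocks is what the launcher uses to size per-sequence kernel buffers. The equivalent in Rust, for reference:

```rust
/// Ceiling division, then round up to a whole number of blocks.
fn divide_round_up(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

fn padded_max_context_len(max_context_len: usize, block_size: usize) -> usize {
    divide_round_up(max_context_len, block_size) * block_size
}
```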
@@ -666,6 +666,7 @@
query, \
key_cache, \
value_cache, \
alibi_slopes, \
num_kv_heads, \
scale, \
softcapping, \
@@ -702,7 +703,8 @@ extern "C" void paged_attention_v1(
void *query, // [num_seqs, num_heads, head_size]
void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
void *value_cache, // [num_blocks, num_heads, head_size, block_size]
int32_t num_kv_heads, // [num_heads]
void *alibi_slopes, // [num_heads]
int32_t num_kv_heads,
float scale,
float softcapping,
uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq]
@@ -740,11 +742,11 @@
reinterpret_cast<T*>(value_cache), \
num_kv_heads, \
scale, \
softcapping, \
block_tables, \
context_lens, \
max_num_blocks_per_seq, \
alibi_slopes, \
reinterpret_cast<float*>(alibi_slopes), \
q_stride, \
kv_block_stride, \
kv_head_stride); \
@@ -770,6 +772,7 @@ void paged_attention_v2_launcher(
void *query,
void *key_cache,
void *value_cache,
void *alibi_slopes,
int num_kv_heads,
float scale,
float softcapping,
@@ -788,8 +791,7 @@
) {
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes = nullptr;
// NOTE: alibi_slopes is optional. It may be nullptr.

T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out);

@@ -843,6 +845,7 @@
query, \
key_cache, \
value_cache, \
alibi_slopes, \
num_kv_heads, \
scale, \
softcapping, \
@@ -882,6 +885,7 @@ extern "C" void paged_attention_v2(
void *query, // [num_seqs, num_heads, head_size]
void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
void *value_cache, // [num_blocks, num_heads, head_size, block_size]
void *alibi_slopes, // [num_heads]
int32_t num_kv_heads,
float scale,
float softcapping,
