Skip to content

Commit

Permalink
Include block engine in paged attn metadata (#576)
Browse files Browse the repository at this point in the history
Helps with Phi3V
  • Loading branch information
EricLBuehler authored Jul 15, 2024
1 parent ebe032e commit 7c50b68
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 32 deletions.
3 changes: 3 additions & 0 deletions mistralrs-core/src/dummy_paged_attention/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler {
// Trait-method entry point: delegates to the inherent method of the same name
// on PagedAttentionScheduler (inherent methods take precedence in Rust method
// resolution, so this is not infinite recursion).
fn free_finished_sequence_groups(&mut self) {
self.free_finished_sequence_groups()
}
// Exposes mutable access to this scheduler's BlockEngine. PagedAttention is
// always active for this scheduler type, so this is always Some.
fn block_engine(&mut self) -> Option<&mut BlockEngine> {
Some(&mut self.block_engine)
}
}
3 changes: 1 addition & 2 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,12 @@ impl Engine {
let res = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);

let block_tables = self.scheduler.block_tables().unwrap();
let block_size = self.scheduler.block_size().unwrap();

let metadata = PagedAttentionMeta {
block_tables,
block_size,
sliding_window: pipeline.get_metadata().sliding_window,
block_engine: self.scheduler.block_engine().unwrap(),
};

pipeline
Expand Down
3 changes: 3 additions & 0 deletions mistralrs-core/src/paged_attention/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler {
// Trait-method entry point: delegates to the inherent method of the same name
// on PagedAttentionScheduler (inherent methods take precedence in Rust method
// resolution, so this is not infinite recursion).
fn free_finished_sequence_groups(&mut self) {
self.free_finished_sequence_groups()
}
// Exposes mutable access to this scheduler's BlockEngine. PagedAttention is
// always active for this scheduler type, so this is always Some.
fn block_engine(&mut self) -> Option<&mut BlockEngine> {
Some(&mut self.block_engine)
}
}
33 changes: 18 additions & 15 deletions mistralrs-core/src/pipeline/inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub mod text_models_inputs_processor {

use crate::{
layers::set_use_matmul_via_f16,
paged_attention::{BlockTables, _PAD_SLOT_ID},
paged_attention::{BlockEngine, _PAD_SLOT_ID},
sequence::Sequence,
};

Expand All @@ -71,11 +71,10 @@ pub mod text_models_inputs_processor {
Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
}

#[derive(Clone)]
pub struct PagedAttentionMeta<'a> {
pub block_tables: &'a BlockTables,
pub sliding_window: Option<usize>,
pub block_size: usize,
pub block_engine: &'a mut BlockEngine,
}

#[derive(Clone, Debug)]
Expand All @@ -101,7 +100,7 @@ pub mod text_models_inputs_processor {
input_seqs: &[&mut Sequence],
device: &Device,
last_n_context_len: Option<(usize, usize)>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
) -> Result<InputMetadata> {
let max_len = input_seqs
.iter()
Expand All @@ -128,8 +127,8 @@ pub mod text_models_inputs_processor {

seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

if let Some(ref paged_attn_metadata) = paged_attn_metadata {
let table = paged_attn_metadata.block_tables.get(seq.id());
if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
let table = paged_attn_metadata.block_engine.block_tables.get(seq.id());
let prompt_len = seq.len();

if table.is_none() {
Expand Down Expand Up @@ -240,7 +239,7 @@ pub mod text_models_inputs_processor {
device: &Device,
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>,
) -> Result<InputMetadata> {
if no_kv_cache {
return get_prompt_input(
Expand Down Expand Up @@ -269,8 +268,12 @@ pub mod text_models_inputs_processor {

seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());

if let Some(ref paged_attn_metadata) = paged_attn_metadata {
let table = paged_attn_metadata.block_tables.get(seq.id()).unwrap();
if let Some(paged_attn_metadata) = &mut paged_attn_metadata {
let table = paged_attn_metadata
.block_engine
.block_tables
.get(seq.id())
.unwrap();

let table = table
.iter()
Expand Down Expand Up @@ -393,7 +396,7 @@ pub mod text_models_inputs_processor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
_: Option<Arc<dyn Any>>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
) -> Result<Box<dyn Any>> {
if is_xlora && !is_prompt {
let InputMetadata {
Expand All @@ -411,7 +414,7 @@ pub mod text_models_inputs_processor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata.clone(),
paged_attn_metadata.as_mut(),
)?;
let InputMetadata {
input: input_ids,
Expand All @@ -429,7 +432,7 @@ pub mod text_models_inputs_processor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?;
Ok(Box::new(ModelInputs {
input_ids,
Expand Down Expand Up @@ -458,7 +461,7 @@ pub mod text_models_inputs_processor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?;
Ok(Box::new(ModelInputs {
input_ids: input_ids.clone(),
Expand Down Expand Up @@ -487,7 +490,7 @@ pub mod text_models_inputs_processor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?;
Ok(Box::new(ModelInputs {
input_ids,
Expand Down Expand Up @@ -517,7 +520,7 @@ pub mod text_models_inputs_processor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?;
Ok(Box::new(ModelInputs {
input_ids,
Expand Down
5 changes: 4 additions & 1 deletion mistralrs-core/src/scheduler/default_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{

use crate::{
engine::TERMINATE_ALL_NEXT_STEP,
paged_attention::BlockTables,
paged_attention::{BlockEngine, BlockTables},
sequence::{Sequence, SequenceState, StopReason},
};

Expand Down Expand Up @@ -323,4 +323,7 @@ impl Scheduler for DefaultScheduler<VecDeque<Sequence>> {
None
}
fn free_finished_sequence_groups(&mut self) {}
// Always None: the default scheduler does not use PagedAttention and
// therefore owns no BlockEngine.
fn block_engine(&mut self) -> Option<&mut BlockEngine> {
None
}
}
5 changes: 3 additions & 2 deletions mistralrs-core/src/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ pub use default_scheduler::{DefaultScheduler, DefaultSchedulerMethod, DefaultSch

use crate::{
paged_attention::{
BlockTables, CacheConfig, PagedAttentionScheduler, PagedAttentionSchedulerConfig,
PagedAttentionSchedulerOutput,
BlockEngine, BlockTables, CacheConfig, PagedAttentionScheduler,
PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
},
sequence::Sequence,
};
Expand Down Expand Up @@ -55,4 +55,5 @@ pub trait Scheduler {
// Paged Attention metadata — each accessor returns None on schedulers that
// do not use paged attention (see the DefaultScheduler impl).
/// Per-sequence block tables, when this scheduler supports PagedAttention.
fn block_tables(&self) -> Option<&BlockTables>;
/// Block size used by the paged KV cache, when paged attention is enabled.
fn block_size(&self) -> Option<usize>;
/// Mutable access to the paged-attention block engine, allowing input
/// processors to free and reallocate a sequence's blocks (as the vision
/// input processors do after expanding a prompt with image tokens).
fn block_engine(&mut self) -> Option<&mut BlockEngine>;
}
6 changes: 3 additions & 3 deletions mistralrs-core/src/vision_models/idefics2_input_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl InputsProcessor for Idefics2ImageProcessor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
other_config: Option<Arc<dyn Any>>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
) -> anyhow::Result<Box<dyn std::any::Any>> {
if is_xlora {
anyhow::bail!("Cannot make inputs for X-LoRA vision model.");
Expand All @@ -142,7 +142,7 @@ impl InputsProcessor for Idefics2ImageProcessor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
} else {
get_completion_input(
Expand All @@ -154,7 +154,7 @@ impl InputsProcessor for Idefics2ImageProcessor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
};
let config = other_config.expect("Need a PreProcessorConfig config.");
Expand Down
11 changes: 8 additions & 3 deletions mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl InputsProcessor for LLaVAInputProcessor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
other_config: Option<Arc<dyn Any>>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
) -> anyhow::Result<Box<dyn Any>> {
if is_xlora {
anyhow::bail!("Cannot make inputs for X-LoRA vision model.");
Expand Down Expand Up @@ -212,6 +212,11 @@ impl InputsProcessor for LLaVAInputProcessor {
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
.collect::<Vec<_>>(),
);
if let Some(ref mut metadata) = paged_attn_metadata {
// Free and then reallocate as appropriate
metadata.block_engine.free_sequence(*seq.id());
metadata.block_engine.allocate(*seq);
}

toks.push(input_ids);
}
Expand All @@ -229,7 +234,7 @@ impl InputsProcessor for LLaVAInputProcessor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
} else {
get_completion_input(
Expand All @@ -238,7 +243,7 @@ impl InputsProcessor for LLaVAInputProcessor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
};
Ok(Box::new(ModelInputs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl InputsProcessor for LLaVANextInputProcessor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
other_config: Option<Arc<dyn Any>>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
) -> anyhow::Result<Box<dyn Any>> {
if is_xlora {
anyhow::bail!("Cannot make inputs for X-LoRA vision model.");
Expand Down Expand Up @@ -256,6 +256,11 @@ impl InputsProcessor for LLaVANextInputProcessor {
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
.collect::<Vec<_>>(),
);
if let Some(ref mut metadata) = paged_attn_metadata {
// Free and then reallocate as appropriate
metadata.block_engine.free_sequence(*seq.id());
metadata.block_engine.allocate(*seq);
}

toks.push(input_ids);
}
Expand All @@ -273,7 +278,7 @@ impl InputsProcessor for LLaVANextInputProcessor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
} else {
get_completion_input(
Expand All @@ -282,7 +287,7 @@ impl InputsProcessor for LLaVANextInputProcessor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
};
Ok(Box::new(ModelInputs {
Expand Down
11 changes: 8 additions & 3 deletions mistralrs-core/src/vision_models/phi3_inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl InputsProcessor for Phi3InputsProcessor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
other_config: Option<Arc<dyn Any>>,
paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
) -> anyhow::Result<Box<dyn Any>> {
if is_xlora {
anyhow::bail!("Cannot make inputs for X-LoRA vision model.");
Expand Down Expand Up @@ -245,6 +245,11 @@ impl InputsProcessor for Phi3InputsProcessor {
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
.collect::<Vec<_>>(),
);
if let Some(ref mut metadata) = paged_attn_metadata {
// Free and then reallocate as appropriate
metadata.block_engine.free_sequence(*seq.id());
metadata.block_engine.allocate(*seq);
}

toks.push(input_ids);
}
Expand All @@ -262,7 +267,7 @@ impl InputsProcessor for Phi3InputsProcessor {
input_seqs,
device,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
} else {
get_completion_input(
Expand All @@ -271,7 +276,7 @@ impl InputsProcessor for Phi3InputsProcessor {
device,
no_kv_cache,
last_n_context_len,
paged_attn_metadata,
paged_attn_metadata.as_mut(),
)?
};

Expand Down

0 comments on commit 7c50b68

Please sign in to comment.