From 7c50b68bb6423cebddc03bc0f1c374bfe1f1dc2b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:51:47 -0400 Subject: [PATCH] Include block engine in paged attn metadata (#576) Helps with Phi3V --- .../src/dummy_paged_attention/scheduler.rs | 3 ++ mistralrs-core/src/engine/mod.rs | 3 +- .../src/paged_attention/scheduler.rs | 3 ++ .../src/pipeline/inputs_processor.rs | 33 ++++++++++--------- .../src/scheduler/default_scheduler.rs | 5 ++- mistralrs-core/src/scheduler/mod.rs | 5 +-- .../vision_models/idefics2_input_processor.rs | 6 ++-- .../llava/llava_inputs_processor.rs | 11 +++++-- .../llava/llava_next_inputs_processor.rs | 11 +++++-- .../vision_models/phi3_inputs_processor.rs | 11 +++++-- 10 files changed, 59 insertions(+), 32 deletions(-) diff --git a/mistralrs-core/src/dummy_paged_attention/scheduler.rs b/mistralrs-core/src/dummy_paged_attention/scheduler.rs index 551c69516..4b19a250a 100644 --- a/mistralrs-core/src/dummy_paged_attention/scheduler.rs +++ b/mistralrs-core/src/dummy_paged_attention/scheduler.rs @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler { fn free_finished_sequence_groups(&mut self) { self.free_finished_sequence_groups() } + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + Some(&mut self.block_engine) + } } diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 36cc9b9a9..ee2597c64 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -271,13 +271,12 @@ impl Engine { let res = { let mut pipeline = get_mut_arcmutex!(self.pipeline); - let block_tables = self.scheduler.block_tables().unwrap(); let block_size = self.scheduler.block_size().unwrap(); let metadata = PagedAttentionMeta { - block_tables, block_size, sliding_window: pipeline.get_metadata().sliding_window, + block_engine: self.scheduler.block_engine().unwrap(), }; pipeline diff --git a/mistralrs-core/src/paged_attention/scheduler.rs b/mistralrs-core/src/paged_attention/scheduler.rs index 551c69516..4b19a250a 100644 --- a/mistralrs-core/src/paged_attention/scheduler.rs +++ b/mistralrs-core/src/paged_attention/scheduler.rs @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler { fn free_finished_sequence_groups(&mut self) { self.free_finished_sequence_groups() } + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + Some(&mut self.block_engine) + } } diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs index d70010117..e3675696a 100644 --- a/mistralrs-core/src/pipeline/inputs_processor.rs +++ b/mistralrs-core/src/pipeline/inputs_processor.rs @@ -47,7 +47,7 @@ pub mod text_models_inputs_processor { use crate::{ layers::set_use_matmul_via_f16, - paged_attention::{BlockTables, _PAD_SLOT_ID}, + paged_attention::{BlockEngine, _PAD_SLOT_ID}, sequence::Sequence, }; @@ -71,11 +71,10 @@ pub mod text_models_inputs_processor { Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg) } - #[derive(Clone)] pub struct PagedAttentionMeta<'a> { - pub block_tables: &'a BlockTables, pub sliding_window: Option, pub block_size: usize, + pub block_engine: &'a mut BlockEngine, } #[derive(Clone, Debug)] @@ -101,7 +100,7 @@ pub mod text_models_inputs_processor { input_seqs: &[&mut Sequence], device: &Device, last_n_context_len: Option<(usize, usize)>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, ) -> Result { let max_len = input_seqs .iter() @@ -128,8 +127,8 @@ pub mod text_models_inputs_processor { seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap()); - if let Some(ref paged_attn_metadata) = paged_attn_metadata { - let table = paged_attn_metadata.block_tables.get(seq.id()); + if let Some(paged_attn_metadata) = &mut paged_attn_metadata { + let table = paged_attn_metadata.block_engine.block_tables.get(seq.id()); let prompt_len = seq.len(); if table.is_none() { @@ -240,7 +239,7 @@ pub mod text_models_inputs_processor { device: &Device, no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, ) -> Result { if no_kv_cache { return get_prompt_input( @@ -269,8 +268,12 @@ pub mod text_models_inputs_processor { seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap()); - if let Some(ref paged_attn_metadata) = paged_attn_metadata { - let table = paged_attn_metadata.block_tables.get(seq.id()).unwrap(); + if let Some(paged_attn_metadata) = &mut paged_attn_metadata { + let table = paged_attn_metadata + .block_engine + .block_tables + .get(seq.id()) + .unwrap(); let table = table .iter() @@ -393,7 +396,7 @@ pub mod text_models_inputs_processor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, _: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> Result> { if is_xlora && !is_prompt { let InputMetadata { @@ -411,7 +414,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata.clone(), + paged_attn_metadata.as_mut(), )?; let InputMetadata { input: input_ids, @@ -429,7 +432,7 @@ pub mod text_models_inputs_processor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, @@ -458,7 +461,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids: input_ids.clone(), @@ -487,7 +490,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, @@ -517,7 +520,7 @@ pub mod text_models_inputs_processor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, diff --git a/mistralrs-core/src/scheduler/default_scheduler.rs b/mistralrs-core/src/scheduler/default_scheduler.rs index c20eb9120..f36599b9a 100644 --- a/mistralrs-core/src/scheduler/default_scheduler.rs +++ b/mistralrs-core/src/scheduler/default_scheduler.rs @@ -6,7 +6,7 @@ use std::{ use crate::{ engine::TERMINATE_ALL_NEXT_STEP, - paged_attention::BlockTables, + paged_attention::{BlockEngine, BlockTables}, sequence::{Sequence, SequenceState, StopReason}, }; @@ -323,4 +323,7 @@ impl Scheduler for DefaultScheduler> { None } fn free_finished_sequence_groups(&mut self) {} + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + None + } } diff --git a/mistralrs-core/src/scheduler/mod.rs b/mistralrs-core/src/scheduler/mod.rs index dcfc41d8a..25aff8a07 100644 --- a/mistralrs-core/src/scheduler/mod.rs +++ b/mistralrs-core/src/scheduler/mod.rs @@ -4,8 +4,8 @@ pub use default_scheduler::{DefaultScheduler, DefaultSchedulerMethod, DefaultSch use crate::{ paged_attention::{ - BlockTables, CacheConfig, PagedAttentionScheduler, PagedAttentionSchedulerConfig, - PagedAttentionSchedulerOutput, + BlockEngine, BlockTables, CacheConfig, PagedAttentionScheduler, + PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, }, sequence::Sequence, }; @@ -55,4 +55,5 @@ pub trait Scheduler { // Paged Attention metadata fn block_tables(&self) -> Option<&BlockTables>; fn block_size(&self) -> Option; + fn block_engine(&mut self) -> Option<&mut BlockEngine>; } diff --git a/mistralrs-core/src/vision_models/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2_input_processor.rs index 30ee9031a..7318835ea 100644 --- a/mistralrs-core/src/vision_models/idefics2_input_processor.rs +++ b/mistralrs-core/src/vision_models/idefics2_input_processor.rs @@ -118,7 +118,7 @@ impl InputsProcessor for Idefics2ImageProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -142,7 +142,7 @@ impl InputsProcessor for Idefics2ImageProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -154,7 +154,7 @@ impl InputsProcessor for Idefics2ImageProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; let config = other_config.expect("Need a PreProcessorConfig config."); diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs index 0aaa1ae4e..4bdaa94cb 100644 --- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs @@ -81,7 +81,7 @@ impl InputsProcessor for LLaVAInputProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -212,6 +212,11 @@ impl InputsProcessor for LLaVAInputProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -229,7 +234,7 @@ impl InputsProcessor for LLaVAInputProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -238,7 +243,7 @@ impl InputsProcessor for LLaVAInputProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; Ok(Box::new(ModelInputs { diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs index f35304db6..0568ddfd4 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs @@ -88,7 +88,7 @@ impl InputsProcessor for LLaVANextInputProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -256,6 +256,11 @@ impl InputsProcessor for LLaVANextInputProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -273,7 +278,7 @@ impl InputsProcessor for LLaVANextInputProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -282,7 +287,7 @@ impl InputsProcessor for LLaVANextInputProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; Ok(Box::new(ModelInputs { diff --git a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs index e0a7ffd63..83ee5785a 100644 --- a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs @@ -76,7 +76,7 @@ impl InputsProcessor for Phi3InputsProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -245,6 +245,11 @@ impl InputsProcessor for Phi3InputsProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -262,7 +267,7 @@ impl InputsProcessor for Phi3InputsProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -271,7 +276,7 @@ impl InputsProcessor for Phi3InputsProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? };