From 639268c73a68e81040e872449050372cdc469455 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sun, 14 Jul 2024 13:46:33 -0400 Subject: [PATCH 1/3] Remove ensure about no paged attn for vision models (#573) --- mistralrs-core/src/pipeline/vision.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index 3e08e87356..fcfe162abb 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -184,11 +184,6 @@ impl Loader for VisionLoader { Device::Cpu }; - anyhow::ensure!( - paged_attn_config.is_none(), - "PagedAttention is not supported for vision models" - ); - let attention_mechanism = if paged_attn_config.is_some() { AttentionImplementation::PagedAttention } else { From ebe032e8991320f973c5ae89270bf760dadde380 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sun, 14 Jul 2024 22:05:10 -0400 Subject: [PATCH 2/3] Add percentage utilization support to paged attn (#574) * Support MISTRALRS_DEBUG=1 in paged attn * Try to handle deadlock * Try to fix deadlock again * Add percentage utilization support * Update docs * Clippy --- docs/PAGED_ATTENTION.md | 6 +- mistralrs-bench/src/main.rs | 37 ++++++-- .../src/dummy_paged_attention/mod.rs | 25 +++--- mistralrs-core/src/engine/mod.rs | 84 +++++++++++++------ mistralrs-core/src/paged_attention/mod.rs | 25 +++--- mistralrs-core/src/utils/memory_usage.rs | 24 ++++++ mistralrs-pyo3/mistralrs.pyi | 6 +- mistralrs-pyo3/src/lib.rs | 10 ++- mistralrs-server/src/main.rs | 37 ++++++-- mistralrs/examples/paged_attn/main.rs | 12 ++- 10 files changed, 190 insertions(+), 76 deletions(-) diff --git a/docs/PAGED_ATTENTION.md b/docs/PAGED_ATTENTION.md index 421f4df99d..db716a68f9 100644 --- a/docs/PAGED_ATTENTION.md +++ b/docs/PAGED_ATTENTION.md @@ -6,7 +6,7 @@ Our Paged Attention implementation has 2 inputs: GPU KV cache memory size, and b > Note: The default block size if not specified is 32. -> Warning: When using dynamic adapter activation or sending re-ISQ requests, it may trigger OOM because the Paged Attention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount (recommended) or disable paged attention. +> Note: if OOM happens (this can be caused by a variety of factors including adapter activation, re-ISQ, and others), it happens because the Paged Attention KV cache has already been allocated. To counter this, either set the KV cache memory to a lower amount or usage percentage (recommended) or disable paged attention entirely for a dynamically allocated cache. **There are more features being added to this:** - GGML model support @@ -23,14 +23,14 @@ Our Paged Attention implementation has 2 inputs: GPU KV cache memory size, and b ## Using the CLI -Add the `--pa-gpu-mem` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is in MBs and the block size means the number of tokens per block. These parameters may be passed on any supported model type. +Add the `--pa-gpu-mem`/`--pa-gpu-mem-usage` and `--pa-blk-size` parameters before the model kind selector. The GPU memory is in MBs and the block size means the number of tokens per block. These parameters may be passed on any supported model type. ``` cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 --isq Q4K plain -m microsoft/Phi-3-mini-128k-instruct -a phi3 ``` ``` -cargo run --release --features cuda -- -i --pa-gpu-mem 8192 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf +cargo run --release --features cuda -- -i --pa-gpu-mem-usage .95 --pa-blk-size 32 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf ``` ## Using the Rust API diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 985958ce5c..bb37c4f71a 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -1,6 +1,7 @@ use candle_core::Device; use clap::Parser; use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table}; +use either::Either; use mistralrs_core::{ initialize_logging, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelDType, @@ -278,11 +279,17 @@ struct Args { #[arg(short, long, value_parser, value_delimiter = ';')] num_device_layers: Option>, - /// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to to the - /// available GPU memory. Paged Attention is only supported on CUDA and is always automatically activated. + /// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to + /// using `pa-gpu-mem-usage` set to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated. #[arg(long = "pa-gpu-mem")] paged_attn_gpu_mem: Option, + /// Percentage of GPU memory to utilize after allocation of KV cache with Paged Attention, from 0 to 1. + /// If this is not set and the device is CUDA, it will default to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated. + /// This is always used over `pa-gpu-mem` if both are specified. + #[arg(long = "pa-gpu-mem-usage")] + paged_attn_gpu_mem_usage: Option, + /// Block size (number of tokens per block) for Paged Attention. If this is not set and the device is CUDA, it will default to 32. /// Paged Attention is only supported on CUDA and is always automatically activated. #[arg(long = "pa-blk-size")] @@ -373,16 +380,32 @@ fn main() -> anyhow::Result<()> { let cache_config = match ( args.paged_attn_block_size, args.paged_attn_gpu_mem, + args.paged_attn_gpu_mem_usage, device.is_cuda(), args.no_paged_attn, ) { - (block_size, None, true, false) => Some(PagedAttentionConfig::new( - block_size, 512, None, // Autodetermine KV cache size + (block_size, None, None, true, false) => Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory + )?), + (block_size, Some(m), None, true, false) => { + Some(PagedAttentionConfig::new(block_size, 512, Either::Left(m))?) + } + (block_size, None, Some(f), true, false) => Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(f), )?), - (block_size, Some(gpu_mem), _, false) => { - Some(PagedAttentionConfig::new(block_size, 512, Some(gpu_mem))?) + (block_size, Some(_m), Some(f), true, false) => { + info!("Both memory size and usage were specified, defaulting to the usage value."); + Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(f), + )?) } - (_, _, _, _) => None, + (_, _, _, _, _) => None, }; let pipeline = loader.load_model_from_hf( diff --git a/mistralrs-core/src/dummy_paged_attention/mod.rs b/mistralrs-core/src/dummy_paged_attention/mod.rs index ea7bf36e25..1f97cd5aed 100644 --- a/mistralrs-core/src/dummy_paged_attention/mod.rs +++ b/mistralrs-core/src/dummy_paged_attention/mod.rs @@ -16,6 +16,7 @@ pub use block_engine_sequence::BlockEngineSequence; pub use cache_engine::{CacheConfig, CacheEngine}; use candle_core::{DType, Device}; pub use config::{ModelConfigLike, ModelConfigMetadata}; +use either::Either; pub use layers::PagedAttention; pub use scheduler::{ PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, @@ -29,14 +30,14 @@ use tracing::info; pub struct PagedAttentionConfig { pub(crate) block_size: Option, pub(crate) mem_cpu: usize, - pub(crate) mem_gpu: Option, + pub(crate) mem_gpu: Either, } impl PagedAttentionConfig { pub fn new( _block_size: Option, _mem_cpu: usize, - _mem_gpu: Option, + _mem_gpu: Either, ) -> anyhow::Result { anyhow::bail!("PagedAttention is only supported for CUDA, compile with feature `cuda`.") } @@ -64,9 +65,9 @@ macro_rules! mb_to_blocks { }; } -/// Memory values are in MBs. Specify block size or the default is 32. +/// Memory values are in MBs or a percentage in [0,1]. Specify block size or the default is 32. pub fn calculate_cache_config( - mem_gpu: Option, + mem_gpu: Either, mem_cpu: usize, block_size: Option, dtype: DType, @@ -79,15 +80,15 @@ pub fn calculate_cache_config( } let dtype_size = dtype.size_in_bytes(); + #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] let mem_gpu = match mem_gpu { - Some(v) => v, - None => { - let free = MemoryUsage.get_memory_available(device)? / SIZE_IN_MB; - info!( - "Automatically using {} MB for Paged Attention KV cache", - free - 512 - ); - free - 512 + Either::Left(v) => v, + Either::Right(f) => { + let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32; + let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32 * f; + let size = (total - free) as usize; + info!("Allocating {size} MB for Paged Attention KV cache"); + size } }; diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index a9b2a89735..36cc9b9a97 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -257,50 +257,80 @@ impl Engine { } SchedulerOutput::PagedAttention { mut output } => { if !output.scheduled.is_empty() { - let mut pipeline = get_mut_arcmutex!(self.pipeline); - let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt(); - let block_tables = self.scheduler.block_tables().unwrap(); - let block_size = self.scheduler.block_size().unwrap(); - - let metadata = PagedAttentionMeta { - block_tables, - block_size, - sliding_window: pipeline.get_metadata().sliding_window, - }; - let mut guards = output .scheduled .iter_mut() .map(|seq| seq.lock().unwrap()) .collect::>(); - let res = pipeline - .step( - &mut guards.iter_mut().map(|seq| &mut **seq).collect::>(), - is_prompt, - &mut self.prefix_cacher, - self.disable_eos_stop, - rng.clone(), - CacheBackendMetadata::PagedAttention { - metadata, - blocks_to_copy: output.blocks_to_copy, - blocks_to_swap_in: output.blocks_to_swap_in, - blocks_to_swap_out: output.blocks_to_swap_out, - }, - ) - .await; + let mut guards_mut = + guards.iter_mut().map(|seq| &mut **seq).collect::>(); + + let res = { + let mut pipeline = get_mut_arcmutex!(self.pipeline); + + let block_tables = self.scheduler.block_tables().unwrap(); + let block_size = self.scheduler.block_size().unwrap(); + + let metadata = PagedAttentionMeta { + block_tables, + block_size, + sliding_window: pipeline.get_metadata().sliding_window, + }; + + pipeline + .step( + &mut guards_mut, + is_prompt, + &mut self.prefix_cacher, + self.disable_eos_stop, + rng.clone(), + CacheBackendMetadata::PagedAttention { + metadata, + blocks_to_copy: output.blocks_to_copy, + blocks_to_swap_in: output.blocks_to_swap_in, + blocks_to_swap_out: output.blocks_to_swap_out, + }, + ) + .await + }; handle_pipeline_forward_error!( "step", res, - &mut guards.iter_mut().map(|seq| &mut **seq).collect::>(), + &mut guards_mut, self.pipeline, 'lp, self.prefix_cacher ); + if self.is_debug { + let ms_from_last_run = run_start.elapsed().as_secs_f64(); + let total_len = guards.len(); + if total_len > 0 { + let lengths = guards + .iter() + .map(|seq| seq.len().to_string()) + .collect::>() + .join(", "); + + let (prompt_lengths, completion_lengths) = if is_prompt { + (lengths, "".to_string()) + } else { + ("".to_string(), lengths) + }; + + tracing::info!( + "Prompt[{}] Completion[{}] - {}ms", + prompt_lengths, + completion_lengths, + ms_from_last_run * 1000., + ); + } + } + if is_prompt { for mut seq in guards { let now = SystemTime::now() diff --git a/mistralrs-core/src/paged_attention/mod.rs b/mistralrs-core/src/paged_attention/mod.rs index 6c6fdbc23c..e9e6cca380 100644 --- a/mistralrs-core/src/paged_attention/mod.rs +++ b/mistralrs-core/src/paged_attention/mod.rs @@ -16,6 +16,7 @@ pub use block_engine_sequence::BlockEngineSequence; pub use cache_engine::{CacheConfig, CacheEngine}; use candle_core::{DType, Device}; pub use config::{ModelConfigLike, ModelConfigMetadata}; +use either::Either; pub use layers::PagedAttention; pub use scheduler::{ PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, @@ -29,14 +30,14 @@ use tracing::info; pub struct PagedAttentionConfig { pub(crate) block_size: Option, pub(crate) mem_cpu: usize, - pub(crate) mem_gpu: Option, + pub(crate) mem_gpu: Either, } impl PagedAttentionConfig { pub fn new( block_size: Option, mem_cpu: usize, - mem_gpu: Option, + mem_gpu: Either, ) -> anyhow::Result { Ok(Self { block_size, @@ -68,9 +69,9 @@ macro_rules! mb_to_blocks { }; } -/// Memory values are in MBs. Specify block size or the default is 32. +/// Memory values are in MBs or a percentage in [0,1]. Specify block size or the default is 32. pub fn calculate_cache_config( - mem_gpu: Option, + mem_gpu: Either, mem_cpu: usize, block_size: Option, dtype: DType, @@ -83,15 +84,15 @@ pub fn calculate_cache_config( } let dtype_size = dtype.size_in_bytes(); + #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] let mem_gpu = match mem_gpu { - Some(v) => v, - None => { - let free = MemoryUsage.get_memory_available(device)? / SIZE_IN_MB; - info!( - "Automatically using {} MB for Paged Attention KV cache", - free - 512 - ); - free - 512 + Either::Left(v) => v, + Either::Right(f) => { + let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32; + let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32 * f; + let size = (total - free) as usize; + info!("Allocating {size} MB for Paged Attention KV cache"); + size } }; diff --git a/mistralrs-core/src/utils/memory_usage.rs b/mistralrs-core/src/utils/memory_usage.rs index a3a620925c..611cd8ee2c 100644 --- a/mistralrs-core/src/utils/memory_usage.rs +++ b/mistralrs-core/src/utils/memory_usage.rs @@ -29,4 +29,28 @@ impl MemoryUsage { } } } + + pub fn get_total_memory(&self, device: &Device) -> Result { + match device { + Device::Cpu => { + let mut sys = System::new_all(); + sys.refresh_cpu(); + Ok(usize::try_from(sys.total_memory())? * KB_TO_BYTES) + } + #[cfg(feature = "cuda")] + Device::Cuda(_) => { + use candle_core::cuda_backend::WrapErr; + Ok(candle_core::cuda::cudarc::driver::result::mem_get_info() + .w()? + .1) + } + #[cfg(not(feature = "cuda"))] + Device::Cuda(_) => { + candle_core::bail!("Cannot get total memory for CUDA device") + } + Device::Metal(_) => { + candle_core::bail!("Cannot get total memory for Metal device") + } + } + } } diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index 39e4881937..82f51b1b6a 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -186,7 +186,7 @@ class Runner: num_device_layers: list[str] | None = None, in_situ_quant: str | None = None, anymoe_config: AnyMoeConfig | None = None, - pa_gpu_mem: int | None = None, + pa_gpu_mem: int | float | None = None, pa_blk_size: int | None = None, no_paged_attn: bool = False, ) -> None: @@ -211,8 +211,8 @@ class Runner: the corresponding number of layers. - `in_situ_quant` sets the optional in-situ quantization for models that are not quantized (not GGUF or GGML). - `anymoe_config` specifies the AnyMoE config. If this is set, then the model will be loaded as an AnyMoE model. - - `pa_gpu_mem` sets GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is - CUDA, it will default to to the available GPU memory. Paged Attention is only supported on CUDA and is always automatically activated. + - `pa_gpu_mem` sets GPU memory to allocate for KV cache with Paged Attention in MBs *OR* the percentage utilization, from 0 to 1. If this is not set and the device is + CUDA, it will default to using 90% of the total memory after allocation of the KV cache. Paged Attention is only supported on CUDA and is always automatically activated. - `pa_blk_size` sets the block size (number of tokens per block) for Paged Attention. If this is not set and the device is CUDA, it will default to 32. Paged Attention is only supported on CUDA and is always automatically activated. - `no_paged_attn` disables Paged Attention on CUDA diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 23aae8ea2c..65a5bff61a 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -351,7 +351,7 @@ impl Runner { num_device_layers: Option>, in_situ_quant: Option, anymoe_config: Option, - pa_gpu_mem: Option, + pa_gpu_mem: Option>, pa_blk_size: Option, no_paged_attn: bool, ) -> PyResult { @@ -466,10 +466,12 @@ impl Runner { // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`. let cache_config = match (pa_blk_size, pa_gpu_mem, device.is_cuda(), no_paged_attn) { (block_size, None, true, false) => Some(PagedAttentionConfig::new( - block_size, 512, None, // Autodetermine KV cache size + block_size, + 512, + Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory )?), - (block_size, Some(gpu_mem), _, false) => { - Some(PagedAttentionConfig::new(block_size, 512, Some(gpu_mem))?) + (block_size, Some(either), true, false) => { + Some(PagedAttentionConfig::new(block_size, 512, either)?) } (_, _, _, _) => None, }; diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index c94696c601..32d73082fc 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -7,6 +7,7 @@ use axum::{ }; use candle_core::{quantized::GgmlDType, Device}; use clap::Parser; +use either::Either; use mistralrs_core::{ get_model_dtype, get_tgt_non_granular_index, initialize_logging, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, @@ -121,11 +122,17 @@ struct Args { #[arg(long = "isq", value_parser = parse_isq)] in_situ_quant: Option, - /// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to to the - /// available GPU memory. Paged Attention is only supported on CUDA and is always automatically activated. + /// GPU memory to allocate for KV cache with Paged Attention in MBs. If this is not set and the device is CUDA, it will default to + /// using `pa-gpu-mem-usage` set to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated. #[arg(long = "pa-gpu-mem")] paged_attn_gpu_mem: Option, + /// Percentage of GPU memory to utilize after allocation of KV cache with Paged Attention, from 0 to 1. + /// If this is not set and the device is CUDA, it will default to `0.9`. Paged Attention is only supported on CUDA and is always automatically activated. + /// This is always used over `pa-gpu-mem` if both are specified. + #[arg(long = "pa-gpu-mem-usage")] + paged_attn_gpu_mem_usage: Option, + /// Block size (number of tokens per block) for Paged Attention. If this is not set and the device is CUDA, it will default to 32. /// Paged Attention is only supported on CUDA and is always automatically activated. #[arg(long = "pa-blk-size")] @@ -338,16 +345,32 @@ async fn main() -> Result<()> { let cache_config = match ( args.paged_attn_block_size, args.paged_attn_gpu_mem, + args.paged_attn_gpu_mem_usage, device.is_cuda(), args.no_paged_attn, ) { - (block_size, None, true, false) => Some(PagedAttentionConfig::new( - block_size, 512, None, // Autodetermine KV cache size + (block_size, None, None, true, false) => Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(0.9), // NOTE(EricLBuehler): default is to use 90% of memory + )?), + (block_size, Some(m), None, true, false) => { + Some(PagedAttentionConfig::new(block_size, 512, Either::Left(m))?) + } + (block_size, None, Some(f), true, false) => Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(f), )?), - (block_size, Some(gpu_mem), _, false) => { - Some(PagedAttentionConfig::new(block_size, 512, Some(gpu_mem))?) + (block_size, Some(_m), Some(f), true, false) => { + info!("Both memory size and usage were specified, defaulting to the usage value."); + Some(PagedAttentionConfig::new( + block_size, + 512, + Either::Right(f), + )?) } - (_, _, _, _) => None, + (_, _, _, _, _) => None, }; let pipeline = loader.load_model_from_hf( diff --git a/mistralrs/examples/paged_attn/main.rs b/mistralrs/examples/paged_attn/main.rs index b73bfe1b4d..dafb8e0e95 100644 --- a/mistralrs/examples/paged_attn/main.rs +++ b/mistralrs/examples/paged_attn/main.rs @@ -1,3 +1,9 @@ +//! This is simply an example of configuring the paged attention. +//! +//! Paged attention is used by default on all CUDA devices. +//! +//! Otherwise, it defaults to 90% usage and block size = 32. + use either::Either; use indexmap::IndexMap; use std::sync::Arc; @@ -43,7 +49,11 @@ fn setup() -> anyhow::Result> { false, DeviceMapMetadata::dummy(), None, - Some(PagedAttentionConfig::new(Some(32), 1024, None)?), // Automatically determine memory usage + Some(PagedAttentionConfig::new( + Some(32), + 1024, + Either::Right(0.9), + )?), // Automatically determine memory usage )?; let config = pipeline .blocking_lock() From 7c50b68bb6423cebddc03bc0f1c374bfe1f1dc2b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:51:47 -0400 Subject: [PATCH 3/3] Include block engine in paged attn metadata (#576) Helps with Phi3V --- .../src/dummy_paged_attention/scheduler.rs | 3 ++ mistralrs-core/src/engine/mod.rs | 3 +- .../src/paged_attention/scheduler.rs | 3 ++ .../src/pipeline/inputs_processor.rs | 33 ++++++++++--------- .../src/scheduler/default_scheduler.rs | 5 ++- mistralrs-core/src/scheduler/mod.rs | 5 +-- .../vision_models/idefics2_input_processor.rs | 6 ++-- .../llava/llava_inputs_processor.rs | 11 +++++-- .../llava/llava_next_inputs_processor.rs | 11 +++++-- .../vision_models/phi3_inputs_processor.rs | 11 +++++-- 10 files changed, 59 insertions(+), 32 deletions(-) diff --git a/mistralrs-core/src/dummy_paged_attention/scheduler.rs b/mistralrs-core/src/dummy_paged_attention/scheduler.rs index 551c69516e..4b19a250ac 100644 --- a/mistralrs-core/src/dummy_paged_attention/scheduler.rs +++ b/mistralrs-core/src/dummy_paged_attention/scheduler.rs @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler { fn free_finished_sequence_groups(&mut self) { self.free_finished_sequence_groups() } + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + Some(&mut self.block_engine) + } } diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 36cc9b9a97..ee2597c643 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -271,13 +271,12 @@ impl Engine { let res = { let mut pipeline = get_mut_arcmutex!(self.pipeline); - let block_tables = self.scheduler.block_tables().unwrap(); let block_size = self.scheduler.block_size().unwrap(); let metadata = PagedAttentionMeta { - block_tables, block_size, sliding_window: pipeline.get_metadata().sliding_window, + block_engine: self.scheduler.block_engine().unwrap(), }; pipeline diff --git a/mistralrs-core/src/paged_attention/scheduler.rs b/mistralrs-core/src/paged_attention/scheduler.rs index 551c69516e..4b19a250ac 100644 --- a/mistralrs-core/src/paged_attention/scheduler.rs +++ b/mistralrs-core/src/paged_attention/scheduler.rs @@ -345,4 +345,7 @@ impl Scheduler for PagedAttentionScheduler { fn free_finished_sequence_groups(&mut self) { self.free_finished_sequence_groups() } + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + Some(&mut self.block_engine) + } } diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs index d70010117e..e3675696a9 100644 --- a/mistralrs-core/src/pipeline/inputs_processor.rs +++ b/mistralrs-core/src/pipeline/inputs_processor.rs @@ -47,7 +47,7 @@ pub mod text_models_inputs_processor { use crate::{ layers::set_use_matmul_via_f16, - paged_attention::{BlockTables, _PAD_SLOT_ID}, + paged_attention::{BlockEngine, _PAD_SLOT_ID}, sequence::Sequence, }; @@ -71,11 +71,10 @@ pub mod text_models_inputs_processor { Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg) } - #[derive(Clone)] pub struct PagedAttentionMeta<'a> { - pub block_tables: &'a BlockTables, pub sliding_window: Option, pub block_size: usize, + pub block_engine: &'a mut BlockEngine, } #[derive(Clone, Debug)] @@ -101,7 +100,7 @@ pub mod text_models_inputs_processor { input_seqs: &[&mut Sequence], device: &Device, last_n_context_len: Option<(usize, usize)>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, ) -> Result { let max_len = input_seqs .iter() @@ -128,8 +127,8 @@ pub mod text_models_inputs_processor { seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap()); - if let Some(ref paged_attn_metadata) = paged_attn_metadata { - let table = paged_attn_metadata.block_tables.get(seq.id()); + if let Some(paged_attn_metadata) = &mut paged_attn_metadata { + let table = paged_attn_metadata.block_engine.block_tables.get(seq.id()); let prompt_len = seq.len(); if table.is_none() { @@ -240,7 +239,7 @@ pub mod text_models_inputs_processor { device: &Device, no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, ) -> Result { if no_kv_cache { return get_prompt_input( @@ -269,8 +268,12 @@ pub mod text_models_inputs_processor { seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap()); - if let Some(ref paged_attn_metadata) = paged_attn_metadata { - let table = paged_attn_metadata.block_tables.get(seq.id()).unwrap(); + if let Some(paged_attn_metadata) = &mut paged_attn_metadata { + let table = paged_attn_metadata + .block_engine + .block_tables + .get(seq.id()) + .unwrap(); let table = table .iter() @@ -393,7 +396,7 @@ pub mod text_models_inputs_processor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, _: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> Result> { if is_xlora && !is_prompt { let InputMetadata { @@ -411,7 +414,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata.clone(), + paged_attn_metadata.as_mut(), )?; let InputMetadata { input: input_ids, @@ -429,7 +432,7 @@ pub mod text_models_inputs_processor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, @@ -458,7 +461,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids: input_ids.clone(), @@ -487,7 +490,7 @@ pub mod text_models_inputs_processor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, @@ -517,7 +520,7 @@ pub mod text_models_inputs_processor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )?; Ok(Box::new(ModelInputs { input_ids, diff --git a/mistralrs-core/src/scheduler/default_scheduler.rs b/mistralrs-core/src/scheduler/default_scheduler.rs index c20eb9120f..f36599b9a0 100644 --- a/mistralrs-core/src/scheduler/default_scheduler.rs +++ b/mistralrs-core/src/scheduler/default_scheduler.rs @@ -6,7 +6,7 @@ use std::{ use crate::{ engine::TERMINATE_ALL_NEXT_STEP, - paged_attention::BlockTables, + paged_attention::{BlockEngine, BlockTables}, sequence::{Sequence, SequenceState, StopReason}, }; @@ -323,4 +323,7 @@ impl Scheduler for DefaultScheduler> { None } fn free_finished_sequence_groups(&mut self) {} + fn block_engine(&mut self) -> Option<&mut BlockEngine> { + None + } } diff --git a/mistralrs-core/src/scheduler/mod.rs b/mistralrs-core/src/scheduler/mod.rs index dcfc41d8a3..25aff8a07b 100644 --- a/mistralrs-core/src/scheduler/mod.rs +++ b/mistralrs-core/src/scheduler/mod.rs @@ -4,8 +4,8 @@ pub use default_scheduler::{DefaultScheduler, DefaultSchedulerMethod, DefaultSch use crate::{ paged_attention::{ - BlockTables, CacheConfig, PagedAttentionScheduler, PagedAttentionSchedulerConfig, - PagedAttentionSchedulerOutput, + BlockEngine, BlockTables, CacheConfig, PagedAttentionScheduler, + PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, }, sequence::Sequence, }; @@ -55,4 +55,5 @@ pub trait Scheduler { // Paged Attention metadata fn block_tables(&self) -> Option<&BlockTables>; fn block_size(&self) -> Option; + fn block_engine(&mut self) -> Option<&mut BlockEngine>; } diff --git a/mistralrs-core/src/vision_models/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2_input_processor.rs index 30ee9031ae..7318835ea2 100644 --- a/mistralrs-core/src/vision_models/idefics2_input_processor.rs +++ b/mistralrs-core/src/vision_models/idefics2_input_processor.rs @@ -118,7 +118,7 @@ impl InputsProcessor for Idefics2ImageProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -142,7 +142,7 @@ impl InputsProcessor for Idefics2ImageProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -154,7 +154,7 @@ impl InputsProcessor for Idefics2ImageProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; let config = other_config.expect("Need a PreProcessorConfig config."); diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs index 0aaa1ae4ee..4bdaa94cbf 100644 --- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs @@ -81,7 +81,7 @@ impl InputsProcessor for LLaVAInputProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -212,6 +212,11 @@ impl InputsProcessor for LLaVAInputProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -229,7 +234,7 @@ impl InputsProcessor for LLaVAInputProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -238,7 +243,7 @@ impl InputsProcessor for LLaVAInputProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; Ok(Box::new(ModelInputs { diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs index f35304db6a..0568ddfd4d 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs @@ -88,7 +88,7 @@ impl InputsProcessor for LLaVANextInputProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -256,6 +256,11 @@ impl InputsProcessor for LLaVANextInputProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -273,7 +278,7 @@ impl InputsProcessor for LLaVANextInputProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -282,7 +287,7 @@ impl InputsProcessor for LLaVANextInputProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? }; Ok(Box::new(ModelInputs { diff --git a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs index e0a7ffd63c..83ee5785af 100644 --- a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs @@ -76,7 +76,7 @@ impl InputsProcessor for Phi3InputsProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, other_config: Option>, - paged_attn_metadata: Option>, + mut paged_attn_metadata: Option>, ) -> anyhow::Result> { if is_xlora { anyhow::bail!("Cannot make inputs for X-LoRA vision model."); @@ -245,6 +245,11 @@ impl InputsProcessor for Phi3InputsProcessor { .map(|x| if *x < 0 { 0u32 } else { *x as u32 }) .collect::>(), ); + if let Some(ref mut metadata) = paged_attn_metadata { + // Free and then reallocate as appropriate + metadata.block_engine.free_sequence(*seq.id()); + metadata.block_engine.allocate(*seq); + } toks.push(input_ids); } @@ -262,7 +267,7 @@ impl InputsProcessor for Phi3InputsProcessor { input_seqs, device, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? } else { get_completion_input( @@ -271,7 +276,7 @@ impl InputsProcessor for Phi3InputsProcessor { device, no_kv_cache, last_n_context_len, - paged_attn_metadata, + paged_attn_metadata.as_mut(), )? };