Preallocated KV cache - optimize decode (#916)
* Initial work on preallocated kv cache

* Dynamically growing now

* More correct prompt exec calculation

* Dummy run

* Typos

* Implement for the rest of the normal models

* Add it to the gguf models

* Fix
EricLBuehler authored Nov 18, 2024
1 parent 5a53001 commit 09a0fd7
Showing 54 changed files with 1,197 additions and 475 deletions.
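
The idea of the change is that decode no longer concatenates a freshly allocated tensor onto the KV cache at every step: each layer keeps a preallocated buffer plus a fill pointer and grows the buffer in fixed increments (NormalCache::CACHE_GROW_SIZE). The cache implementation itself is not among the hunks shown below, so the following is only a minimal sketch of the idea on top of candle's Tensor::zeros/cat/narrow/slice_set; the struct name, field layout, and the 512-token grow step are placeholders rather than the crate's actual code.

use candle_core::{DType, Device, Result, Tensor};

// Sketch only -- not code from this commit.
const CACHE_GROW_SIZE: usize = 512; // placeholder; the real constant lives on NormalCache

struct KvCacheSketch {
    all_data: Tensor,       // (batch, num_kv_heads, capacity, head_dim)
    current_seq_len: usize, // how many positions along dim 2 are filled
    capacity: usize,
}

impl KvCacheSketch {
    fn new(num_kv_heads: usize, head_dim: usize, dtype: DType, dev: &Device) -> Result<Self> {
        // One allocation up front instead of one per decode step.
        let all_data = Tensor::zeros((1, num_kv_heads, CACHE_GROW_SIZE, head_dim), dtype, dev)?;
        Ok(Self { all_data, current_seq_len: 0, capacity: CACHE_GROW_SIZE })
    }

    // Append new keys (or values) along the sequence dim and return the filled prefix.
    fn append(&mut self, k: &Tensor) -> Result<Tensor> {
        let new_len = k.dim(2)?;
        while self.current_seq_len + new_len > self.capacity {
            // Out of room: extend the buffer by one grow step along the sequence dim.
            let (b, h, _, d) = self.all_data.dims4()?;
            let extra =
                Tensor::zeros((b, h, CACHE_GROW_SIZE, d), self.all_data.dtype(), self.all_data.device())?;
            self.all_data = Tensor::cat(&[&self.all_data, &extra], 2)?;
            self.capacity += CACHE_GROW_SIZE;
        }
        // Write in place into the already-allocated region.
        self.all_data.slice_set(k, 2, self.current_seq_len)?;
        self.current_seq_len += new_len;
        // Attention only ever sees the filled prefix, never the zero padding.
        self.all_data.narrow(2, 0, self.current_seq_len)
    }

    fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }
}
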
1 change: 1 addition & 0 deletions mistralrs-core/src/dummy_paged_attention/config.rs
@@ -8,6 +8,7 @@ pub trait ModelConfigLike {
}
}

#[derive(Clone)]
pub struct ModelConfigMetadata {
pub num_layers: usize,
pub hidden_size: usize,
69 changes: 55 additions & 14 deletions mistralrs-core/src/engine/mod.rs
@@ -1,3 +1,4 @@
use candle_core::Tensor;
use once_cell::sync::Lazy;
use std::{
collections::HashMap,
@@ -13,7 +14,7 @@ use crate::{
aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx},
pipeline::{
text_models_inputs_processor::PagedAttentionMeta, AdapterInstruction, CacheBackendMetadata,
CacheInstruction,
CacheInstruction, EitherCache, NormalCache,
},
request::NormalRequest,
response::CompletionChoice,
@@ -162,6 +163,7 @@ impl Engine {
CacheInstruction::Out
} else {
CacheInstruction::Reset {
load_preallocated_cache: false,
reset_non_granular: false,
adapter_inst: AdapterInstruction::None,
}
@@ -203,15 +205,15 @@ impl Engine {
}

if scheduled.prompt.len() > 0 {
let throughput_start = Instant::now();
let logits = {
let prompt_exec_time = {
let mut pipeline = get_mut_arcmutex!(self.pipeline);

// Run the prompt seqs
let post_op = if !self.no_kv_cache {
CacheInstruction::Out
} else {
CacheInstruction::Reset {
load_preallocated_cache: false,
reset_non_granular: false,
adapter_inst: AdapterInstruction::None,
}
@@ -232,6 +234,7 @@
rng.clone(),
CacheBackendMetadata::DefaultInstructions {
pre_op: CacheInstruction::Reset {
load_preallocated_cache: true,
reset_non_granular: false,
adapter_inst,
},
@@ -241,16 +244,15 @@
.await
};

handle_pipeline_forward_error!(
let prompt_exec_time = handle_pipeline_forward_error!(
"prompt step",
logits,
prompt_exec_time,
&mut scheduled.prompt,
self.pipeline,
'lp,
self.prefix_cacher
);

let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
prompt_ts = Some(
@@ -259,9 +261,7 @@
.iter()
.map(|seq| seq.get_toks().len())
.sum::<usize>() as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64(),
/ prompt_exec_time.as_secs_f64(),
);
}

@@ -280,8 +280,8 @@
.as_millis();
#[allow(clippy::cast_precision_loss)]
let prompt_tok_per_sec =
seq.len() as f32 / (now - seq.timestamp()) as f32;
seq.prompt_tok_per_sec = prompt_tok_per_sec * 1000.;
seq.len() as f32 / prompt_exec_time.as_secs_f32();
seq.prompt_tok_per_sec = prompt_tok_per_sec;
seq.prompt_timestamp = Some(now);
}
last_completion_ids = vec![];
@@ -733,9 +733,6 @@ impl Engine {
is_chat,
best_of,
)));
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!");

let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();

@@ -790,6 +787,49 @@
.tok_trie
.as_ref()
.map(|x| (**x).clone());

let cache = get_mut_arcmutex!(self.pipeline).cache().clone();
let seq_preallocated_cache = if let EitherCache::Normal(_cache) = cache {
let metadata = get_mut_arcmutex!(self.pipeline).get_metadata();
let model_metadata = metadata
.model_metadata
.as_ref()
.expect("If a model has a NormalCache it must have a model metadata");
let max_seq_len = NormalCache::CACHE_GROW_SIZE;
let kv_shape = (
1usize,
model_metadata.num_kv_heads(),
max_seq_len,
model_metadata.head_dim(),
);
let dtype = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.activation_dtype;
let seq_cache =
Tensor::zeros(kv_shape, dtype, &get_mut_arcmutex!(self.pipeline).device());
let seq_cache = match seq_cache {
Ok(x) => x,
Err(_) => {
request
.response
.send(Response::InternalError(
"Failed to allocate preallocated KV cache."
.to_string()
.into(),
))
.await
.expect("Expected receiver.");
return;
}
};
Some(seq_cache)
} else {
None
};

let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!");
let seq = Sequence::new_waiting(
prompt_tokens.clone(),
prompt_text.clone(),
Expand Down Expand Up @@ -821,6 +861,7 @@ impl Engine {
image_generation_format,
seq_step_type,
diffusion_params.clone(),
seq_preallocated_cache,
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
seq.prefill(
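
For scale, the seq_preallocated_cache tensor created above is a single zeros tensor of shape (1, num_kv_heads, CACHE_GROW_SIZE, head_dim) in the activation dtype, allocated when the request is added rather than during decode. A rough sizing under hypothetical model dimensions (the head count, head_dim, grow step, and dtype below are illustrative, not values from this commit):

// Hypothetical sizing of the per-sequence preallocated KV tensor.
fn prealloc_bytes(num_kv_heads: usize, grow_size: usize, head_dim: usize, dtype_bytes: usize) -> usize {
    // The batch dim is 1, so it drops out of the product.
    num_kv_heads * grow_size * head_dim * dtype_bytes
}

fn main() {
    // e.g. 8 KV heads, head_dim 128, a 512-token grow step, 2-byte f16/bf16 activations:
    let per_seq = prealloc_bytes(8, 512, 128, 2);
    println!("{} KiB per new sequence", per_seq / 1024); // prints "1024 KiB per new sequence"
}
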
9 changes: 9 additions & 0 deletions mistralrs-core/src/layers_masker.rs
@@ -4,6 +4,8 @@ use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType};

use crate::pipeline::KvCache;

// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
pub struct CausalMasker;

@@ -47,6 +49,13 @@ impl<'a> PastKvLenCache for &'a [Option<(Tensor, Tensor)>] {
}
}

impl PastKvLenCache for Vec<KvCache> {
fn get_past_kv_len(&self) -> Result<usize> {
let kv_cache_1 = &self[0];
Ok(kv_cache_1.current_seq_len())
}
}

impl<'a> PastKvLenCache for &'a [usize] {
fn get_past_kv_len(&self) -> Result<usize> {
if self.windows(2).all(|w| w[0] == w[1]) {
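
The new impl lets CausalMasker read the number of already-cached positions straight from the first layer's KvCache (all layers advance in lockstep, so layer 0 is representative). For intuition, here is a tiny framework-free sketch of how that past length shifts the causal mask during decode; the real masker builds a candle tensor mask, not a Vec of bools:

// Each of the `seq_len` new tokens may attend to all `past` cached positions
// plus the new tokens at or before its own position.
fn causal_mask(seq_len: usize, past: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..past + seq_len).map(|j| j <= past + i).collect())
        .collect()
}

fn main() {
    // Prompt of 4 tokens, empty cache: the usual lower-triangular mask.
    assert_eq!(causal_mask(4, 0)[0], vec![true, false, false, false]);
    // One decode step with 4 positions already cached: the new token sees all 5.
    assert_eq!(causal_mask(1, 4)[0], vec![true; 5]);
}
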
45 changes: 45 additions & 0 deletions mistralrs-core/src/lib.rs
@@ -9,6 +9,7 @@ pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::time::Instant;
use std::{
cell::RefCell,
error::Error,
@@ -22,6 +23,8 @@ use std::{
time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

mod aici;
mod cuda;
@@ -358,6 +361,48 @@ impl MistralRs {

let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);

// Do a dummy run
if matches!(category, ModelCategory::Text | ModelCategory::Vision { .. }) {
let clone_sender = sender.read().unwrap().clone();
tokio::task::block_in_place(|| {
let (tx, mut rx) = channel(1);
let req = Request::Normal(NormalRequest {
id: 0,
messages: RequestMessage::Completion {
text: "dummy".to_string(),
echo_prompt: false,
best_of: 1,
},
sampling_params: SamplingParams {
max_len: Some(1),
..SamplingParams::deterministic()
},
response: tx,
return_logprobs: false,
is_streaming: true,
constraint: Constraint::None,
suffix: None,
adapters: None,
tool_choice: None,
tools: None,
logits_processors: None,
});
info!("Beginning dummy run.");
let start = Instant::now();
clone_sender.blocking_send(req).unwrap();

if let Some(_resp) = rx.blocking_recv() {
let end = Instant::now();
info!(
"Dummy run completed in {}s.",
end.duration_since(start).as_secs_f64()
);
} else {
warn!("Dummy run failed!");
}
});
}

Arc::new(Self {
engine_id,
sender,
21 changes: 12 additions & 9 deletions mistralrs-core/src/models/gemma.rs
@@ -20,7 +20,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
@@ -254,7 +254,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -321,7 +321,7 @@ impl Attention {
)?
}
None => {
let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
let (k, v) = kv_cache.append(&k, &v)?;

Sdpa.run_attention(
&q,
@@ -399,7 +399,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -430,7 +430,7 @@ pub struct Model {
lm_head: Arc<dyn QuantMethod>,
hidden_size: usize,
device: Device,
cache: Cache,
cache: EitherCache,
max_seq_len: usize,
mapper: Box<dyn DeviceMapper + Send + Sync>,
cfg: ModelConfigMetadata,
@@ -529,7 +529,10 @@ impl Model {
))?),
device: normal_loading_metadata.real_device,
hidden_size: cfg.hidden_size,
cache: Cache::new(cfg.num_hidden_layers, false),
cache: EitherCache::Normal(NormalCache::new(
cfg.num_hidden_layers,
cfg.max_position_embeddings,
)),
max_seq_len: default_max_position_embeddings(),
mapper,
cfg: ModelConfigMetadata {
@@ -554,13 +557,13 @@
) -> Result<Tensor> {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let mut cache = self.cache.lock();
let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
.unwrap_or(cache as &dyn PastKvLenCache),
xs.dtype(),
)?;
for (i, layer) in self.layers.iter().enumerate() {
@@ -673,7 +676,7 @@ impl NormalModel for Model {
) -> Result<Tensor> {
unimplemented!()
}
fn cache(&self) -> &Cache {
fn cache(&self) -> &EitherCache {
&self.cache
}
fn device(&self) -> &Device {
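
gemma.rs (and, per the commit message, the rest of the normal and GGUF models) now stores an EitherCache so that models ported to the preallocated path and models still on the older cache can share the same traits. The crate's actual definitions are not in the hunks shown here; the shape below is only a guess consistent with how the diff uses the type (EitherCache::Normal(NormalCache::new(...)), self.cache.normal().0 inside forward, and cache() returning &EitherCache):

// Guessed shape, not the crate's definitions.
use std::sync::{Arc, Mutex, MutexGuard};

struct KvCache;                         // per-layer preallocated cache (see the sketch near the top)
#[derive(Clone)]
struct Cache;                           // stand-in for the pre-existing cache type
struct NormalCacheInner(Vec<KvCache>);  // .0 is what forward() indexes per layer

#[derive(Clone)]
struct NormalCache(Arc<Mutex<NormalCacheInner>>);

#[derive(Clone)]
enum EitherCache {
    Normal(NormalCache),
    Full(Cache),
}

impl EitherCache {
    // Models constructed with EitherCache::Normal (like gemma.rs above) can call
    // this safely; the lock is what lets forward(&self) mutate the cache.
    fn normal(&self) -> MutexGuard<'_, NormalCacheInner> {
        match self {
            EitherCache::Normal(NormalCache(inner)) => inner.lock().unwrap(),
            EitherCache::Full(_) => panic!("`normal` called on a Full cache"),
        }
    }
}

fn main() {
    let cache = EitherCache::Normal(NormalCache(Arc::new(Mutex::new(NormalCacheInner(Vec::new())))));
    let mut guard = cache.normal();
    let layer_caches = &mut guard.0; // one KvCache per decoder layer
    assert!(layer_caches.is_empty());
}
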