Skip to content

Commit

Permalink
Fixes for kv cache growing strategy (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Nov 18, 2024
1 parent 09a0fd7 commit 32ec1fb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mistralrs-core/src/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ impl Sdpa {
);
}

let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?.contiguous()?;
let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?.contiguous()?;
let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
if !get_use_matmul_via_f16() {
#[cfg(feature = "cuda")]
Expand Down
5 changes: 4 additions & 1 deletion mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ impl Engine {
.model_metadata
.as_ref()
.expect("If a model has a NormalCache it must have a model metadata");
let max_seq_len = NormalCache::CACHE_GROW_SIZE;
let n_tokens = prompt_tokens.len();
let required_blocks =
(n_tokens + NormalCache::CACHE_GROW_SIZE - 1) / NormalCache::CACHE_GROW_SIZE;
let max_seq_len = required_blocks * NormalCache::CACHE_GROW_SIZE;
let kv_shape = (
1usize,
model_metadata.num_kv_heads(),
Expand Down
8 changes: 6 additions & 2 deletions mistralrs-core/src/pipeline/cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,13 @@ impl SingleCache {
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad);
};
// Expand kv cache
if self.current_seq_len + seq_len > self.capacity_seq_len {
self.capacity_seq_len += NormalCache::CACHE_GROW_SIZE;
if self.capacity_seq_len < self.max_seq_len {
let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
let n_blocks_needed =
(diff + NormalCache::CACHE_GROW_SIZE - 1) / NormalCache::CACHE_GROW_SIZE;
self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
if self.capacity_seq_len > self.max_seq_len {
candle_core::bail!(
"kv-cache: requested capacity ({}) above max seq len ({})",
self.capacity_seq_len,
Expand Down

0 comments on commit 32ec1fb

Please sign in to comment.