Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Nov 18, 2024
1 parent 6dd7147 commit bc3449a
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 37 deletions.
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ impl Cache {
}
}

pub struct DefaultCacheManager;
pub struct FullCacheManager;

enum SeqCache {
Normal,
Expand Down Expand Up @@ -655,7 +655,7 @@ fn clone_out_cache(
}
}

impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for DefaultCacheManager {
impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCacheManager {
fn clone_in_cache(
&self,
pipeline: &T,
Expand Down
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::cache_manager::DefaultCacheManager;
use super::cache_manager::FullCacheManager;
use super::{
get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, QuantizationKind, TokenSource,
Expand Down Expand Up @@ -456,10 +456,10 @@ impl IsqPipelineMixin for GGMLPipeline {

impl CacheManagerMixin for GGMLPipeline {
fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
}
fn clone_out_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
}
fn set_none_cache(
&self,
Expand All @@ -468,7 +468,7 @@ impl CacheManagerMixin for GGMLPipeline {
modify_draft_cache: bool,
load_preallocated_cache: bool,
) {
DefaultCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
if reset_non_granular {
self.reset_non_granular_state()
}
Expand Down
33 changes: 25 additions & 8 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::cache_manager::DefaultCacheManager;
use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::{
get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
Expand Down Expand Up @@ -487,12 +487,12 @@ impl Loader for GGUFLoader {
let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.normal().0.len(),
Model::Phi2(ref model) => model.cache.full().lock().len(),
Model::Phi2(ref model) => model.cache.normal().0.len(),
Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
Model::Phi3(ref model) => model.cache.full().lock().len(),
Model::Phi3(ref model) => model.cache.normal().0.len(),
Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
Model::Starcoder2(ref model) => model.cache.full().lock().len(),
Model::Qwen2(ref model) => model.cache.full().lock().len(),
Model::Starcoder2(ref model) => model.cache.normal().0.len(),
Model::Qwen2(ref model) => model.cache.normal().0.len(),
};

if chat_template.bos_token.is_none() && bos.is_some() {
Expand Down Expand Up @@ -570,10 +570,18 @@ impl IsqPipelineMixin for GGUFPipeline {

impl CacheManagerMixin for GGUFPipeline {
fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
if matches!(self.cache(), EitherCache::Full(_)) {
FullCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
} else {
NormalCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
}
}
fn clone_out_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
if matches!(self.cache(), EitherCache::Full(_)) {
FullCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
} else {
NormalCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
}
}
fn set_none_cache(
&self,
Expand All @@ -582,7 +590,16 @@ impl CacheManagerMixin for GGUFPipeline {
modify_draft_cache: bool,
load_preallocated_cache: bool,
) {
DefaultCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
if matches!(self.cache(), EitherCache::Full(_)) {
FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
} else {
NormalCacheManager.set_none_cache(
self,
seqs,
modify_draft_cache,
load_preallocated_cache,
);
}
if reset_non_granular {
self.reset_non_granular_state()
}
Expand Down
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::cache_manager::{DefaultCacheManager, NormalCacheManager};
use super::cache_manager::{FullCacheManager, NormalCacheManager};
use super::{
get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
Expand Down Expand Up @@ -519,14 +519,14 @@ impl IsqPipelineMixin for NormalPipeline {
impl CacheManagerMixin for NormalPipeline {
fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
if matches!(self.model.cache(), EitherCache::Full(_)) {
DefaultCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
} else {
NormalCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
}
}
fn clone_out_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
if matches!(self.model.cache(), EitherCache::Full(_)) {
DefaultCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
} else {
NormalCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
}
Expand All @@ -539,7 +539,7 @@ impl CacheManagerMixin for NormalPipeline {
load_preallocated_cache: bool,
) {
if matches!(self.model.cache(), EitherCache::Full(_)) {
DefaultCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
} else {
NormalCacheManager.set_none_cache(
self,
Expand Down
22 changes: 7 additions & 15 deletions mistralrs-core/src/pipeline/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
};

use super::{
cache_manager::DefaultCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
cache_manager::FullCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
AdapterActivationMixin, AnyMoePipelineMixin, CacheBackendMetadata, CacheInstruction,
CacheManager, CacheManagerMixin, EitherCache, ForwardInputsResult, GeneralMetadata,
IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths, PreProcessingMixin,
Expand Down Expand Up @@ -246,20 +246,12 @@ impl IsqPipelineMixin for SpeculativePipeline {
// TODO: correct handling of cloning in and out for normal cache
impl CacheManagerMixin for SpeculativePipeline {
fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_in_cache(
&*get_mut_arcmutex!(self.draft),
seqs,
modify_draft_cache,
);
DefaultCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
FullCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, modify_draft_cache);
FullCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
}
fn clone_out_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_out_cache(
&*get_mut_arcmutex!(self.draft),
seqs,
modify_draft_cache,
);
DefaultCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
FullCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, modify_draft_cache);
FullCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
}
fn set_none_cache(
&self,
Expand All @@ -268,13 +260,13 @@ impl CacheManagerMixin for SpeculativePipeline {
modify_draft_cache: bool,
load_preallocated_cache: bool,
) {
DefaultCacheManager.set_none_cache(
FullCacheManager.set_none_cache(
&*get_mut_arcmutex!(self.draft),
seqs,
modify_draft_cache,
load_preallocated_cache,
);
DefaultCacheManager.set_none_cache(
FullCacheManager.set_none_cache(
&*get_mut_arcmutex!(self.target),
seqs,
false,
Expand Down
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::cache_manager::DefaultCacheManager;
use super::cache_manager::FullCacheManager;
use super::isq::UqffFullSer;
use super::{
get_model_paths, get_xlora_paths, AdapterActivationMixin, AnyMoePipelineMixin, CacheManager,
Expand Down Expand Up @@ -436,10 +436,10 @@ impl IsqPipelineMixin for VisionPipeline {

impl CacheManagerMixin for VisionPipeline {
fn clone_in_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_in_cache(self, seqs, modify_draft_cache)
}
fn clone_out_cache(&self, seqs: &mut [&mut Sequence], modify_draft_cache: bool) {
DefaultCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
FullCacheManager.clone_out_cache(self, seqs, modify_draft_cache)
}
fn set_none_cache(
&self,
Expand All @@ -448,7 +448,7 @@ impl CacheManagerMixin for VisionPipeline {
modify_draft_cache: bool,
load_preallocated_cache: bool,
) {
DefaultCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, load_preallocated_cache);
if reset_non_granular {
self.reset_non_granular_state()
}
Expand Down

0 comments on commit bc3449a

Please sign in to comment.