Commit e366096
Implement for the rest of the normal models
EricLBuehler committed Nov 18, 2024
1 parent f288629 commit e366096
Showing 11 changed files with 129 additions and 107 deletions.
17 changes: 10 additions & 7 deletions mistralrs-core/src/models/gemma.rs
@@ -20,7 +20,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
- Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
+ EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
@@ -254,7 +254,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -321,7 +321,7 @@ impl Attention {
)?
}
None => {
- let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
+ let (k, v) = kv_cache.append(&k, &v)?;

Sdpa.run_attention(
&q,
@@ -399,7 +399,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -529,7 +529,10 @@ impl Model {
))?),
device: normal_loading_metadata.real_device,
hidden_size: cfg.hidden_size,
- cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
+ cache: EitherCache::Normal(NormalCache::new(
+ cfg.num_hidden_layers,
+ cfg.max_position_embeddings,
+ )),
max_seq_len: default_max_position_embeddings(),
mapper,
cfg: ModelConfigMetadata {
@@ -554,13 +557,13 @@
) -> Result<Tensor> {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
- let mut cache = self.cache.full().lock();
+ let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
- .unwrap_or(&*cache as &dyn PastKvLenCache),
+ .unwrap_or(cache as &dyn PastKvLenCache),
xs.dtype(),
)?;
for (i, layer) in self.layers.iter().enumerate() {
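
The gemma.rs hunks all apply one pattern: a typed per-layer KvCache replaces the raw Option<(Tensor, Tensor)>, and the model is built with a NormalCache instead of the lock-guarded full Cache. The sketch below condenses that pattern. It is reconstructed from the diff alone; the type and method names come from the hunks, while the wrapper functions, the candle_core imports, and the exact parameter and return types are illustrative assumptions rather than code from the commit.

// Condensed sketch of the pattern the gemma.rs hunks apply, assuming the
// crate-internal paths shown in the diff; the wrapper functions are illustrative.
use candle_core::{Result, Tensor};

use crate::pipeline::{EitherCache, KvCache, NormalCache};

// Per-layer attention now receives a typed KvCache instead of a raw
// Option<(Tensor, Tensor)> and appends the freshly projected K/V in one call.
fn append_to_layer_cache(kv_cache: &mut KvCache, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
    // Old: let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
    let (k, v) = kv_cache.append(k, v)?;
    // The returned tensors hold past plus current keys/values, ready for attention.
    Ok((k, v))
}

// Model construction swaps the lock-guarded "full" cache for a NormalCache that is
// sized per hidden layer and bounded by max_position_embeddings.
fn build_cache(num_hidden_layers: usize, max_position_embeddings: usize) -> EitherCache {
    // Old: EitherCache::Full(Cache::new(num_hidden_layers, false))
    EitherCache::Normal(NormalCache::new(num_hidden_layers, max_position_embeddings))
}
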
23 changes: 10 additions & 13 deletions mistralrs-core/src/models/gemma2.rs
@@ -19,7 +19,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
- Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
+ EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
@@ -262,7 +262,7 @@ impl Attention {
sliding_attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -336,14 +336,8 @@ impl Attention {
}
None => {
// self.sliding_window is None if !self.use_sliding_window
- let (k, v, mask) = Cache::update_kv_cache_sliding_window(
- kv_cache,
- k,
- v,
- mask,
- self.sliding_window,
- false,
- )?;
+ let (k, v, mask) =
+ kv_cache.append_sliding_window(&k, &v, mask, self.sliding_window)?;

Sdpa.run_attention(
&q,
@@ -437,7 +431,7 @@ impl DecoderLayer {
sliding_attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -581,7 +575,10 @@ impl Model {
))?),
device: normal_loading_metadata.real_device,
hidden_size: cfg.hidden_size,
- cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
+ cache: EitherCache::Normal(NormalCache::new(
+ cfg.num_hidden_layers,
+ cfg.max_position_embeddings,
+ )),
max_seq_len: cfg.max_position_embeddings,
mapper,
sliding_window: cfg.sliding_window,
@@ -608,7 +605,7 @@
) -> Result<Tensor> {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
- let mut cache = self.cache.full().lock();
+ let cache = &mut self.cache.normal().0;
let attention_mask =
CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
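
gemma2.rs (and mistral.rs and mixtral.rs below) use the sliding-window flavour of the same API: the multi-line Cache::update_kv_cache_sliding_window call collapses into a single append_sliding_window call on the layer's KvCache. A minimal sketch of that call site, with the method name and argument order taken from the hunks and the parameter and return types inferred rather than confirmed:

// Sliding-window variant of the cache append used by gemma2/mistral/mixtral in this
// commit. Types here are inferred from the surrounding diff context.
use candle_core::{Result, Tensor};

use crate::pipeline::KvCache;

fn append_to_sliding_window_cache(
    kv_cache: &mut KvCache,
    k: &Tensor,
    v: &Tensor,
    attention_mask: Option<&Tensor>,
    sliding_window: Option<usize>, // None when the layer does not use sliding attention
) -> Result<(Tensor, Tensor, Option<Tensor>)> {
    // Old: Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, sliding_window, false)
    // New: the cache appends K/V under the window policy and hands back the mask to use
    // for attention alongside the accumulated tensors.
    let (k, v, mask) = kv_cache.append_sliding_window(k, v, attention_mask, sliding_window)?;
    Ok((k, v, mask))
}
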
2 changes: 1 addition & 1 deletion mistralrs-core/src/models/llama.rs
@@ -135,7 +135,7 @@ impl CausalSelfAttention {
)?
}
None => {
- let (k, v) = kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
+ let (k, v) = kv_cache.append(&k, &v)?;

Sdpa.run_attention(
&q,
25 changes: 11 additions & 14 deletions mistralrs-core/src/models/mistral.rs
@@ -21,7 +21,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
- Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
+ EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
@@ -231,7 +231,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -298,14 +298,8 @@ impl Attention {
)?
}
None => {
- let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
- kv_cache,
- k,
- v,
- attention_mask,
- self.sliding_window,
- false,
- )?;
+ let (k, v, attn_mask) =
+ kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;

Sdpa.run_attention(
&q,
@@ -383,7 +377,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -547,7 +541,10 @@ impl Model {
lm_head,
sliding_window: cfg.sliding_window,
device: normal_loading_metadata.real_device,
- cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
+ cache: EitherCache::Normal(NormalCache::new(
+ cfg.num_hidden_layers,
+ cfg.max_position_embeddings,
+ )),
max_seq_len: cfg.max_position_embeddings,
mapper,
cfg: ModelConfigMetadata {
@@ -597,13 +594,13 @@
flash_params: &FlashParams,
) -> Result<Tensor> {
let mut xs = input_embeds;
- let mut cache = self.cache.full().lock();
+ let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
- .unwrap_or(&*cache as &dyn PastKvLenCache),
+ .unwrap_or(cache as &dyn PastKvLenCache),
self.sliding_window,
xs.dtype(),
)?;
25 changes: 11 additions & 14 deletions mistralrs-core/src/models/mixtral.rs
@@ -19,7 +19,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
- Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
+ EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
@@ -126,7 +126,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -193,14 +193,8 @@ impl Attention {
)?
}
None => {
- let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
- kv_cache,
- k,
- v,
- attention_mask,
- self.sliding_window,
- false,
- )?;
+ let (k, v, attn_mask) =
+ kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;

Sdpa.run_attention(
&q,
@@ -435,7 +429,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -579,7 +573,10 @@ impl Model {
lm_head,
sliding_window: cfg.sliding_window,
device: normal_loading_metadata.real_device,
- cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
+ cache: EitherCache::Normal(NormalCache::new(
+ cfg.num_hidden_layers,
+ cfg.max_position_embeddings,
+ )),
max_seq_len: cfg.max_position_embeddings,
mapper,
cfg: ModelConfigMetadata {
@@ -603,13 +600,13 @@
flash_params: &FlashParams,
) -> Result<Tensor> {
let mut xs = self.embed_tokens.forward(input_ids)?;
- let mut cache = self.cache.full().lock();
+ let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
- .unwrap_or(&*cache as &dyn PastKvLenCache),
+ .unwrap_or(cache as &dyn PastKvLenCache),
self.sliding_window,
xs.dtype(),
)?;
17 changes: 10 additions & 7 deletions mistralrs-core/src/models/phi2.rs
@@ -25,7 +25,7 @@ use crate::{
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
- Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
+ EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
@@ -236,7 +236,7 @@ impl Attention {
mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -317,7 +317,7 @@ impl Attention {
)?
}
None => {
- let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
+ let (k, v) = kv_cache.append(&k, &v)?;

Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
}
@@ -383,7 +383,7 @@ impl DecoderLayer {
mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
- kv_cache: &mut Option<(Tensor, Tensor)>,
+ kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
@@ -512,7 +512,10 @@ impl Model {
layers,
final_layernorm,
lm_head,
- cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
+ cache: EitherCache::Normal(NormalCache::new(
+ cfg.num_hidden_layers,
+ cfg.max_position_embeddings,
+ )),
device: normal_loading_metadata.real_device,
max_seq_len: cfg.max_position_embeddings,
mapper,
@@ -537,13 +540,13 @@
flash_params: &FlashParams,
) -> Result<Tensor> {
let mut xs = input_ids.apply(&self.embed_tokens)?;
- let mut cache = self.cache.full().lock();
+ let cache = &mut self.cache.normal().0;
let mask = CausalMasker.make_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
- .unwrap_or(&*cache as &dyn PastKvLenCache),
+ .unwrap_or(cache as &dyn PastKvLenCache),
xs.dtype(),
)?;
for (i, layer) in self.layers.iter().enumerate() {
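
On the forward-pass side, every model switches from locking the old full cache to borrowing the new one directly; only the access pattern changes, while the masking logic around it stays the same. A small sketch of that access, where the normal().0 borrow is taken from the hunks above and the function itself is an illustrative stand-in for the models' forward methods:

// Cache access in the rewritten forward passes. Only EitherCache::normal() and the
// .0 field access come from the diff; everything else is an illustrative stand-in.
use crate::pipeline::EitherCache;

fn borrow_layer_caches(cache_holder: &mut EitherCache) {
    // Old: let mut cache = self.cache.full().lock();
    //      ... .unwrap_or(&*cache as &dyn PastKvLenCache)
    // New: no lock guard; the normal cache's state is borrowed directly and passed to
    //      the causal masker as `cache as &dyn PastKvLenCache`.
    let cache = &mut cache_holder.normal().0;
    let _ = cache; // the decoder-layer loop then reads and updates K/V through this state
}
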
[Diffs for the remaining 5 changed files were not loaded.]
