diff --git a/Cargo.lock b/Cargo.lock
index 149bb2f73..055ea33e6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -386,8 +386,8 @@ checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
 
 [[package]]
 name = "candle-core"
-version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30"
+version = "0.8.0"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=77a6cc6#77a6cc66e35d8f242b6b943b765a6a87e2ca0e8c"
 dependencies = [
  "accelerate-src",
  "byteorder",
@@ -417,8 +417,8 @@ dependencies = [
 
 [[package]]
 name = "candle-flash-attn"
-version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30"
+version = "0.8.0"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=77a6cc6#77a6cc66e35d8f242b6b943b765a6a87e2ca0e8c"
 dependencies = [
  "anyhow",
  "bindgen_cuda 0.1.5",
@@ -428,16 +428,16 @@ dependencies = [
 
 [[package]]
 name = "candle-kernels"
-version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30"
+version = "0.8.0"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=77a6cc6#77a6cc66e35d8f242b6b943b765a6a87e2ca0e8c"
 dependencies = [
  "bindgen_cuda 0.1.5",
 ]
 
 [[package]]
 name = "candle-metal-kernels"
-version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30"
+version = "0.8.0"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=77a6cc6#77a6cc66e35d8f242b6b943b765a6a87e2ca0e8c"
 dependencies = [
  "metal 0.27.0",
  "once_cell",
@@ -447,8 +447,8 @@ dependencies = [
 
 [[package]]
 name = "candle-nn"
-version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=11495ab#11495abeba0c1d27b680168edd0cd52189a7ca30"
+version = "0.8.0"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=77a6cc6#77a6cc66e35d8f242b6b943b765a6a87e2ca0e8c"
 dependencies = [
  "accelerate-src",
  "candle-core",
diff --git a/Cargo.toml b/Cargo.toml
index 415ec2248..bfd242879 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,8 +25,8 @@ license = "MIT"
 
 [workspace.dependencies]
 anyhow = "1.0.80"
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6" }
 serde = "1.0.197"
 serde_json = "1.0.114"
 indexmap = { version = "2.2.5", features = ["serde"] }
diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml
index 228a29c13..9e7e88dd3 100644
--- a/mistralrs-core/Cargo.toml
+++ b/mistralrs-core/Cargo.toml
@@ -17,7 +17,7 @@ candle-core.workspace = true
 candle-nn.workspace = true
 serde.workspace = true
 serde_json.workspace = true
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "11495ab", optional = true }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "77a6cc6", optional = true }
 dirs = "5.0.1"
 hf-hub = "0.3.2"
 thiserror = "1.0.57"
diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index faf23da54..ad0ddb787 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -92,24 +92,31 @@ fn naive_sdpa(
     head_dim: usize,
     sdpa_params: &SdpaParams,
 ) -> Result<Tensor> {
-    let mut att = MatMul.matmul_affine_div(
-        &q.contiguous()?,
-        &k.t()?.contiguous()?,
-        (head_dim as f64).sqrt(),
-    )?;
-    if let Some(softcap) = sdpa_params.softcap {
-        att = (att / softcap as f64)?;
-        att = att.tanh()?;
-        att = (att * softcap as f64)?;
-    }
+    if let Some(mask) = mask {
+        let mut att = MatMul.matmul(q, &k.t()?)?;
+        if let Some(softcap) = sdpa_params.softcap {
+            att = (att / softcap as f64)?;
+            att = att.tanh()?;
+            att = (att * softcap as f64)?;
+        }
+
+        let att = candle_nn::ops::attn_softmax_last_dim(&att, mask, 1. / (head_dim as f32).sqrt())?;
+        MatMul.matmul(&att, v)
+    } else {
+        let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;
+        if let Some(softcap) = sdpa_params.softcap {
+            att = (att / softcap as f64)?;
+            att = att.tanh()?;
+            att = (att * softcap as f64)?;
+        }
 
-    let att = match mask {
-        Some(m) => att.broadcast_add(m)?,
-        None => att,
-    };
-    let att = candle_nn::ops::softmax_last_dim(&att)?;
-    // Convert to contiguous as matmul doesn't support strided vs for now.
-    MatMul.matmul(&att, &v.contiguous()?)
+        let att = match mask {
+            Some(m) => att.broadcast_add(m)?,
+            None => att,
+        };
+        let att = candle_nn::ops::softmax_last_dim(&att)?;
+        MatMul.matmul(&att, v)
+    }
 }
 
 pub struct SdpaParams {
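The masked branch above leans on `candle_nn::ops::attn_softmax_last_dim`, the fused scale-plus-mask softmax that the candle 0.8 bump brings in. As an unfused reference (a sketch of the assumed semantics, not the fused kernel), it should agree with scaling the raw scores by `1/sqrt(head_dim)`, broadcasting the additive mask, and taking a softmax over the key dimension:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Unfused reference for what the fused call is assumed to compute:
// softmax_last_dim(att * scale + mask), with the mask broadcast over batch/head dims.
fn attn_softmax_reference(att: &Tensor, mask: &Tensor, scale: f32) -> Result<Tensor> {
    let scaled = att.affine(scale as f64, 0.0)?; // att * 1/sqrt(head_dim)
    let biased = scaled.broadcast_add(mask)?;    // add the 0 / -inf bias
    candle_nn::ops::softmax_last_dim(&biased)    // softmax over the key dimension
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch=1, heads=1, tgt=2, keys=2) scores and a 2-D causal 0 / -inf mask.
    let att = Tensor::zeros((1, 1, 2, 2), DType::F32, &dev)?;
    let mask = Tensor::new(&[[0f32, f32::NEG_INFINITY], [0f32, 0f32]], &dev)?;
    println!("{}", attn_softmax_reference(&att, &mask, 1.0)?);
    Ok(())
}
```

The unmasked `else` branch keeps the previous `matmul_affine_div` plus `softmax_last_dim` path, since there is no bias to fold into the softmax there.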
diff --git a/mistralrs-core/src/layers_masker.rs b/mistralrs-core/src/layers_masker.rs
index c41f2a486..40b58b20e 100644
--- a/mistralrs-core/src/layers_masker.rs
+++ b/mistralrs-core/src/layers_masker.rs
@@ -122,31 +122,28 @@ impl CausalMasker {
         return Ok(k_cache_1.dims()[2]);
     }
 
-    pub fn make_causal_mask_as_attn_bias(
+    pub fn make_causal_mask_matrix(
         &self,
         input_ids: &Tensor,
         cache: &dyn PastKvLenCache,
         dtype: DType,
-        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
         let past_kv_len = cache.get_past_kv_len()?;
-        let (b_sz, tgt_len) = input_ids.dims2()?;
+        let (_b_sz, tgt_len) = input_ids.dims2()?;
         if tgt_len == 1 {
             return Ok(None);
         }
 
         let causal_mask = {
-            let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-            let mask = mask
-                .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+            let mask = self
+                .make_mask(tgt_len, past_kv_len, input_ids.device())?
                 .to_dtype(DType::U8)?;
             Some(mask)
         };
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
         let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask =
-                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+            let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
             // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
             let mask = masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
@@ -164,20 +161,19 @@ impl CausalMasker {
         Ok(mask)
     }
 
-    pub fn make_causal_mask_with_sliding_window_as_attn_bias(
+    pub fn make_sliding_window_causal_mask_matrix(
         &self,
         input_ids: &Tensor,
         cache: &dyn PastKvLenCache,
         sliding_window: Option<usize>,
         dtype: DType,
-        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
         if sliding_window.is_none() {
-            return self.make_causal_mask_as_attn_bias(input_ids, cache, dtype, n_attn_heads);
+            return self.make_causal_mask_matrix(input_ids, cache, dtype);
         }
         let sliding_window = sliding_window.unwrap();
         let past_kv_len = cache.get_past_kv_len()?;
-        let (b_sz, tgt_len) = input_ids.dims2()?;
+        let (_b_sz, tgt_len) = input_ids.dims2()?;
         if tgt_len == 1 {
             return Ok(None);
         }
@@ -186,9 +182,7 @@ impl CausalMasker {
             let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
             let diagonal = past_kv_len as isize - sliding_window as isize - 1;
             let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
-            let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
-            let mask = mask
-                .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+            let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
                 .to_dtype(DType::U8)?;
 
             Some(mask)
@@ -196,8 +190,7 @@
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
         let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask =
-                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+            let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
             // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
             let mask = masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
@@ -216,60 +209,105 @@
     }
 
     #[deprecated(
-        since = "0.1.10",
-        note = "use `make_causal_mask_as_attn_bias` instead! \
-        This is *not* compatible with `Sdpa`"
+        since = "0.3.3",
+        note = "use `make_causal_mask_matrix` instead. This is incompatible with `Sdpa`."
     )]
-    pub fn make_causal_mask(
+    pub fn make_causal_mask_as_attn_bias(
         &self,
         input_ids: &Tensor,
-        cache: &[Option<(Tensor, Tensor)>],
+        cache: &dyn PastKvLenCache,
+        dtype: DType,
+        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
-        let past_kv_len = self.calculate_past_kv_len(cache)?;
+        let past_kv_len = cache.get_past_kv_len()?;
         let (b_sz, tgt_len) = input_ids.dims2()?;
         if tgt_len == 1 {
             return Ok(None);
         }
 
-        let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-        let mask = mask
-            .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
-            .to_dtype(DType::U8)?;
+        let causal_mask = {
+            let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+            let mask = mask
+                .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+                .to_dtype(DType::U8)?;
+            Some(mask)
+        };
 
-        Ok(Some(mask))
+        let zero = Tensor::new(0.0f32, input_ids.device())?;
+        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
+            let mask =
+                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+            // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
+            let mask = masked_fill(
+                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
+                &mask,
+                f32::NEG_INFINITY,
+            )?;
+
+            Ok(mask)
+        });
+        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
+            Some(mask?)
+        } else {
+            None
+        };
+        Ok(mask)
     }
 
     #[deprecated(
-        since = "0.1.10",
-        note = "use `make_causal_mask_with_sliding_window_as_attn_bias` instead! \
-        This is *not* compatible with `Sdpa`"
+        since = "0.3.3",
+        note = "use `make_sliding_window_causal_mask_matrix` instead. This is incompatible with `Sdpa`."
     )]
-    pub fn make_causal_mask_with_sliding_window(
+    pub fn make_causal_mask_with_sliding_window_as_attn_bias(
         &self,
         input_ids: &Tensor,
-        cache: &[Option<(Tensor, Tensor)>],
+        cache: &dyn PastKvLenCache,
         sliding_window: Option<usize>,
+        dtype: DType,
+        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
         if sliding_window.is_none() {
             #[allow(deprecated)]
-            return self.make_causal_mask(input_ids, cache);
+            return self.make_causal_mask_as_attn_bias(input_ids, cache, dtype, n_attn_heads);
         }
         let sliding_window = sliding_window.unwrap();
-        let past_kv_len = self.calculate_past_kv_len(cache)?;
+        let past_kv_len = cache.get_past_kv_len()?;
         let (b_sz, tgt_len) = input_ids.dims2()?;
         if tgt_len == 1 {
             return Ok(None);
        }
 
-        let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-        let diagonal = past_kv_len as isize - sliding_window as isize - 1;
-        let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
-        let mask = mask
-            .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
-            .to_dtype(DType::U8)?;
+        let causal_mask = {
+            let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+            let diagonal = past_kv_len as isize - sliding_window as isize - 1;
+            let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
+            let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
+            let mask = mask
+                .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+                .to_dtype(DType::U8)?;
 
-        Ok(Some(mask))
+            Some(mask)
+        };
+
+        let zero = Tensor::new(0.0f32, input_ids.device())?;
+        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
+            let mask =
+                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+            // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
+            let mask = masked_fill(
+                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
+                &mask,
+                f32::NEG_INFINITY,
+            )?;
+
+            Ok(mask)
+        });
+        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
+            Some(mask?)
+        } else {
+            None
+        };
+        Ok(mask)
     }
 
     pub fn apply_mask_one_and_zero(
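For callers, the practical change is the mask's shape: the renamed helpers hand back a single 2-D matrix instead of a bias pre-expanded over batch and heads, which is why the `n_attn_heads` argument disappears from every call site below. A standalone sketch of that matrix (a hypothetical helper for illustration, not the `CausalMasker` code): rows are query positions, columns run over `tgt_len + past_kv_len` key positions, with `0.0` where attention is allowed and `-inf` above the causal diagonal.

```rust
use candle_core::{Device, Result, Tensor};

// Hypothetical illustration of the matrix the make_causal_mask_matrix-style API returns:
// shape (tgt_len, tgt_len + past_kv_len); the attention code broadcasts it over batch
// and head dimensions instead of receiving a pre-expanded (b, n_heads, tgt, total) bias.
fn causal_mask_matrix(tgt_len: usize, past_kv_len: usize, dev: &Device) -> Result<Tensor> {
    let total = tgt_len + past_kv_len;
    let data: Vec<f32> = (0..tgt_len)
        .flat_map(|i| {
            // Query i (absolute position i + past_kv_len) may see keys j <= i + past_kv_len.
            (0..total).map(move |j| if j <= i + past_kv_len { 0.0 } else { f32::NEG_INFINITY })
        })
        .collect();
    Tensor::from_vec(data, (tgt_len, total), dev)
}

fn main() -> Result<()> {
    // 3 new tokens attending over 2 cached positions plus themselves.
    let mask = causal_mask_matrix(3, 2, &Device::Cpu)?;
    println!("{mask}");
    Ok(())
}
```

The sliding-window variant additionally masks keys that sit more than `sliding_window` positions behind the query, which is what the `apply_tril` diagonal of `past_kv_len - sliding_window - 1` selects in the hunk above.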
diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs
index 8a1b4ef1f..344eaebf8 100644
--- a/mistralrs-core/src/models/gemma.rs
+++ b/mistralrs-core/src/models/gemma.rs
@@ -555,14 +555,13 @@ impl Model {
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let attention_mask = CausalMasker.make_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs
index 6448e6407..f0da39c91 100644
--- a/mistralrs-core/src/models/gemma2.rs
+++ b/mistralrs-core/src/models/gemma2.rs
@@ -609,20 +609,14 @@ impl Model {
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let attention_mask =
+            CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
+        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
+            Some(self.sliding_window),
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
-        let sliding_attention_mask = CausalMasker
-            .make_causal_mask_with_sliding_window_as_attn_bias(
-                input_ids,
-                &*cache,
-                Some(self.sliding_window),
-                xs.dtype(),
-                self.layers[0].self_attn.num_heads,
-            )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs
index 0f10d03e9..afc33dcc5 100644
--- a/mistralrs-core/src/models/llama.rs
+++ b/mistralrs-core/src/models/llama.rs
@@ -411,14 +411,13 @@ impl Llama {
     ) -> Result<Tensor> {
         let mut x = self.wte.forward(input_ids)?;
         let mut cache = self.kv_cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
             metadata
                 .as_ref()
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             x.dtype(),
-            self.blocks[0].attn.num_attention_heads,
         )?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = self.mapper.map(x, block_idx)?;
diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs
index af86aa02d..e14087593 100644
--- a/mistralrs-core/src/models/mistral.rs
+++ b/mistralrs-core/src/models/mistral.rs
@@ -598,7 +598,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = input_embeds;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -606,7 +606,6 @@
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs
index 147204187..2afbaf68c 100644
--- a/mistralrs-core/src/models/mixtral.rs
+++ b/mistralrs-core/src/models/mixtral.rs
@@ -604,7 +604,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -612,7 +612,6 @@
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs
index 9f07075cc..461e68fd8 100644
--- a/mistralrs-core/src/models/phi2.rs
+++ b/mistralrs-core/src/models/phi2.rs
@@ -538,14 +538,13 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = input_ids.apply(&self.embed_tokens)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs
index c371ab0f9..8e88f82a2 100644
--- a/mistralrs-core/src/models/phi3.rs
+++ b/mistralrs-core/src/models/phi3.rs
@@ -537,7 +537,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -545,7 +545,6 @@
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs
index 0f02c1c68..64b4c9665 100644
--- a/mistralrs-core/src/models/phi3_5_moe.rs
+++ b/mistralrs-core/src/models/phi3_5_moe.rs
@@ -659,7 +659,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -667,7 +667,6 @@
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs
index d7f37ba32..96b12b4c9 100644
--- a/mistralrs-core/src/models/quantized_llama.rs
+++ b/mistralrs-core/src/models/quantized_llama.rs
@@ -655,14 +655,13 @@ impl ModelWeights {
     ) -> Result<Tensor> {
         let mut layer_in = self.tok_embeddings.forward(x)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
             x,
             metadata
                 .as_ref()
                 .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             DType::F32,
-            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs
index e5f7d868f..900838052 100644
--- a/mistralrs-core/src/models/quantized_phi2.rs
+++ b/mistralrs-core/src/models/quantized_phi2.rs
@@ -347,12 +347,7 @@ impl ModelWeights {
     ) -> Result<Tensor> {
         let mut xs = self.tok_embeddings.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
-            input_ids,
-            &*cache,
-            DType::F32,
-            self.layers[0].n_head,
-        )?;
+        let mask = CausalMasker.make_causal_mask_matrix(input_ids, &*cache, DType::F32)?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             let residual = &xs;
diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs
index bbae3bedd..9717163ac 100644
--- a/mistralrs-core/src/models/quantized_phi3.rs
+++ b/mistralrs-core/src/models/quantized_phi3.rs
@@ -373,7 +373,7 @@ impl ModelWeights {
         let (_b_sz, seq_len) = input_ids.dims2()?;
         let mut xs = self.tok_embeddings.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -381,7 +381,6 @@
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             Some(self.max_seq_len),
             DType::F32,
-            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
diff --git a/mistralrs-core/src/models/quantized_qwen2.rs b/mistralrs-core/src/models/quantized_qwen2.rs
index 7278c8329..8a60d6208 100644
--- a/mistralrs-core/src/models/quantized_qwen2.rs
+++ b/mistralrs-core/src/models/quantized_qwen2.rs
@@ -382,14 +382,13 @@ impl ModelWeights {
     ) -> Result<Tensor> {
         let mut layer_in = self.tok_embeddings.forward(x)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
             x,
             metadata
                 .as_ref()
                 .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             DType::F32,
-            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
diff --git a/mistralrs-core/src/models/quantized_starcoder2.rs b/mistralrs-core/src/models/quantized_starcoder2.rs
index 4510f42d4..09971e9b1 100644
--- a/mistralrs-core/src/models/quantized_starcoder2.rs
+++ b/mistralrs-core/src/models/quantized_starcoder2.rs
@@ -369,14 +369,13 @@ impl ModelWeights {
         let (_b_sz, seq_len) = input_ids.dims2()?;
         let mut xs = self.tok_embeddings.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             DType::F32,
-            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs
index 7d3cf7b03..a71540c33 100644
--- a/mistralrs-core/src/models/qwen2.rs
+++ b/mistralrs-core/src/models/qwen2.rs
@@ -538,7 +538,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -546,7 +546,6 @@ impl Model {
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             Some(self.sliding_window),
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs
index 458c48cea..3a3f2c7a7 100644
--- a/mistralrs-core/src/models/starcoder2.rs
+++ b/mistralrs-core/src/models/starcoder2.rs
@@ -526,7 +526,7 @@ impl Model {
 
         let mut xs = self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -534,7 +534,6 @@ impl Model {
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/paged_attention/layers/paged_attention.rs b/mistralrs-core/src/paged_attention/layers/paged_attention.rs
index 78707d350..9518e7e9f 100644
--- a/mistralrs-core/src/paged_attention/layers/paged_attention.rs
+++ b/mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -99,7 +99,7 @@ impl PagedAttention {
             Some(sc) => ((att / sc)?.tanh()? * sc)?,
         };
 
-        let att = att.broadcast_add(mask)?;
+        let att = att.broadcast_add(mask.unsqueeze(0)?.unsqueeze(0)?)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         if key_value_heads != attention_heads {
             let value_repeat = if key_value_heads == 1 {
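The paged-attention path keeps an explicit pre-softmax `broadcast_add`, so the new 2-D mask has to be lifted back to four dimensions first; that is all the added `unsqueeze(0)?.unsqueeze(0)?` does. A shape-only sketch, assuming `att` is laid out as `(batch, heads, tgt_len, total_len)` as in the surrounding code:

```rust
use candle_core::{DType, Device, Result, Tensor};

// (tgt, total) -> (1, 1, tgt, total); broadcast_add then expands it across batch and heads.
fn add_causal_bias(att: &Tensor, mask: &Tensor) -> Result<Tensor> {
    att.broadcast_add(&mask.unsqueeze(0)?.unsqueeze(0)?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let att = Tensor::zeros((2, 4, 3, 3), DType::F32, &dev)?; // (batch, heads, tgt, total)
    let mask = Tensor::zeros((3, 3), DType::F32, &dev)?;      // the 2-D mask matrix
    assert_eq!(add_causal_bias(&att, &mask)?.dims(), &[2, 4, 3, 3]);
    Ok(())
}
```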
diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs
index c4a86ce08..9b43a553b 100644
--- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs
+++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs
@@ -510,14 +510,13 @@ impl LLaVALLM for Llama {
     ) -> Result<Tensor> {
         let mut x = input_embed;
         let mut cache = self.kv_cache.lock();
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let mask = CausalMasker.make_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             x.dtype(),
-            self.blocks[0].attn.num_attention_heads,
         )?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = self.mapper.map(x, block_idx)?;
diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs
index c14928df7..d16391394 100644
--- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs
+++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs
@@ -510,7 +510,7 @@ impl Model {
     ) -> Result<Tensor> {
         let mut xs = input_embeds;
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -518,7 +518,6 @@ impl Model {
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/vision_models/mllama/text.rs b/mistralrs-core/src/vision_models/mllama/text.rs
index ccc029b48..fce58ddf2 100644
--- a/mistralrs-core/src/vision_models/mllama/text.rs
+++ b/mistralrs-core/src/vision_models/mllama/text.rs
@@ -647,11 +647,10 @@ impl MLlamaTextModel {
         let mut hidden_states =
             self.embed_tokens.forward(input_ids)?;
         let mut cache = self.cache.lock();
-        let self_mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let self_mask = CausalMasker.make_causal_mask_matrix(
             input_ids,
             &seqlen_offsets as &dyn PastKvLenCache,
             hidden_states.dtype(),
-            self.cfg.num_attn_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs
index 50a707a30..a8cb176e1 100644
--- a/mistralrs-core/src/vision_models/phi3.rs
+++ b/mistralrs-core/src/vision_models/phi3.rs
@@ -1067,7 +1067,7 @@ impl Model {
             self.embed_tokens.forward(input_ids)?
         };
         let mut cache = self.cache.lock();
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             metadata
                 .as_ref()
@@ -1075,7 +1075,6 @@ impl Model {
                 .unwrap_or(&*cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/vision_models/qwen2vl/mod.rs b/mistralrs-core/src/vision_models/qwen2vl/mod.rs
index bb4774f92..cf9db9fed 100644
--- a/mistralrs-core/src/vision_models/qwen2vl/mod.rs
+++ b/mistralrs-core/src/vision_models/qwen2vl/mod.rs
@@ -281,12 +281,11 @@ impl Qwen2VLModel {
         context_lens: Vec<(usize, usize)>,
         flash_params: &FlashParams,
     ) -> Result<Tensor> {
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &seqlen_offsets as &dyn PastKvLenCache,
             self.text.cfg.sliding_window,
             self.text.dtype,
-            self.text.cfg.num_attn_heads,
         )?;
 
         let input_embeds = if pixel_values.is_some() || pixel_values_videos.is_some() {
diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs
index 7de3ec675..986eee9b4 100644
--- a/mistralrs-core/src/xlora_models/gemma.rs
+++ b/mistralrs-core/src/xlora_models/gemma.rs
@@ -622,12 +622,8 @@ impl XLoraModel {
             self.cache.lock()
         };
         let xs = self.embed_tokens.forward(input_ids)?;
-        let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
-            input_ids,
-            &*cache,
-            xs.dtype(),
-            self.layers[0].self_attn.num_heads,
-        )?;
+        let attention_mask =
+            CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/xlora_models/gemma2.rs b/mistralrs-core/src/xlora_models/gemma2.rs
index 94a95a94d..dfd558523 100644
--- a/mistralrs-core/src/xlora_models/gemma2.rs
+++ b/mistralrs-core/src/xlora_models/gemma2.rs
@@ -667,20 +667,14 @@ impl Model {
         } else {
             self.cache.lock()
         };
-        let attention_mask = CausalMasker.make_causal_mask_as_attn_bias(
+        let attention_mask =
+            CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
+        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
+            Some(self.sliding_window),
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
-        let sliding_attention_mask = CausalMasker
-            .make_causal_mask_with_sliding_window_as_attn_bias(
-                input_ids,
-                &*cache,
-                Some(self.sliding_window),
-                xs.dtype(),
-                self.layers[0].self_attn.num_heads,
-            )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs
index adab7d661..8092ade7a 100644
--- a/mistralrs-core/src/xlora_models/llama.rs
+++ b/mistralrs-core/src/xlora_models/llama.rs
@@ -460,12 +460,7 @@ impl XLoraLlama {
         } else {
             self.kv_cache.lock()
         };
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
-            input_ids,
-            &*cache,
-            x.dtype(),
-            self.blocks[0].attn.num_attention_heads,
-        )?;
+        let mask = CausalMasker.make_causal_mask_matrix(input_ids, &*cache, x.dtype())?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = self.mapper.map(x, block_idx)?;
             x = block.forward(
diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs
index b22960a36..ac0c04f95 100644
--- a/mistralrs-core/src/xlora_models/mistral.rs
+++ b/mistralrs-core/src/xlora_models/mistral.rs
@@ -617,12 +617,11 @@ impl XLoraModel {
             self.cache.lock()
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs
index 67565f650..35e7e9e6b 100644
--- a/mistralrs-core/src/xlora_models/mixtral.rs
+++ b/mistralrs-core/src/xlora_models/mixtral.rs
@@ -758,12 +758,11 @@ impl XLoraModel {
             self.cache.lock()
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs
index c95baec20..97ab9e68f 100644
--- a/mistralrs-core/src/xlora_models/phi2.rs
+++ b/mistralrs-core/src/xlora_models/phi2.rs
@@ -582,12 +582,7 @@ impl Model {
         } else {
             self.cache.lock()
         };
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
-            input_ids,
-            &*cache,
-            xs.dtype(),
-            self.layers[0].self_attn.num_heads,
-        )?;
+        let mask = CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs
index e7335e580..2816cd8cb 100644
--- a/mistralrs-core/src/xlora_models/phi3.rs
+++ b/mistralrs-core/src/xlora_models/phi3.rs
@@ -546,12 +546,11 @@ impl Model {
         } else {
             self.cache.lock()
         };
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs
index 656cc53d2..4a5362738 100644
--- a/mistralrs-core/src/xlora_models/quantized_llama.rs
+++ b/mistralrs-core/src/xlora_models/quantized_llama.rs
@@ -837,12 +837,7 @@ impl ModelWeights {
         } else {
             self.cache.lock()
         };
-        let mask = CausalMasker.make_causal_mask_as_attn_bias(
-            x,
-            &*cache,
-            DType::F32,
-            self.layers[0].n_head,
-        )?;
+        let mask = CausalMasker.make_causal_mask_matrix(x, &*cache, DType::F32)?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
                 layer_in = mapper.map(layer_in, i)?;
diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs
index e63ddfe57..e212f8603 100644
--- a/mistralrs-core/src/xlora_models/quantized_phi3.rs
+++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs
@@ -426,12 +426,11 @@ impl ModelWeights {
         } else {
             self.cache.lock()
         };
-        let mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             Some(self.max_seq_len),
             DType::F32,
-            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
diff --git a/mistralrs-core/src/xlora_models/starcoder2.rs b/mistralrs-core/src/xlora_models/starcoder2.rs
index 097777142..b60d8e03c 100644
--- a/mistralrs-core/src/xlora_models/starcoder2.rs
+++ b/mistralrs-core/src/xlora_models/starcoder2.rs
@@ -600,12 +600,11 @@ impl Model {
         } else {
             self.cache.lock()
        };
-        let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
+        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             self.sliding_window,
             xs.dtype(),
-            self.layers[0].self_attn.num_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {