From d9e7e7b7e33cae32fbf91b1af944bdb14be0771b Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 22 Nov 2024 18:40:20 -0500
Subject: [PATCH 1/2] Use sdpa to benefit from cublaslt for vllama

---
 mistralrs-core/src/attention.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index 936813aca..c496ca015 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -92,7 +92,7 @@ fn naive_sdpa(
     head_dim: usize,
     sdpa_params: &SdpaParams,
 ) -> Result<Tensor> {
-    if let Some(mask) = mask {
+    if mask.is_some_and(|mask| mask.rank() == 2) {
         let mut att = MatMul.matmul(q, &k.t()?)?;
         if let Some(softcap) = sdpa_params.softcap {
             att = (att / softcap as f64)?;
@@ -100,7 +100,11 @@
             att = (att * softcap as f64)?;
         }

-        let att = candle_nn::ops::attn_softmax_last_dim(&att, mask, 1. / (head_dim as f32).sqrt())?;
+        let att = candle_nn::ops::attn_softmax_last_dim(
+            &att,
+            mask.unwrap(),
+            1. / (head_dim as f32).sqrt(),
+        )?;
         MatMul.matmul(&att, v)
     } else {
         let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;

From ee67853816b4087a9c507c51bb387859e19b24ff Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 22 Nov 2024 22:33:09 -0500
Subject: [PATCH 2/2] Handle different ranks

---
 mistralrs-core/src/attention.rs | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index c496ca015..abb5bf9d1 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -184,7 +184,15 @@ impl Sdpa {
             let k = k.flatten(0, 1)?;
             let q = q.flatten(0, 1)?;
             let v = v.flatten(0, 1)?;
-            let attention_bias = mask.cloned();
+            let attention_bias = match mask {
+                Some(mask) if mask.rank() == 3 => Some(mask.clone()),
+                Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
+                Some(mask) if mask.rank() == 2 => Some(mask.unsqueeze(0)?),
+                Some(mask) => {
+                    candle_core::bail!("cublaslt attn mask: rank must be 3, 4, or 2")
+                }
+                None => None,
+            };

             // If attention_bias is set, we fuse the add by giving it as the output matrix
             // and setting beta to 1.0
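
A minimal, self-contained sketch of the mask-rank normalization added in PATCH 2/2, assuming candle_core on CPU; the helper name normalize_cublaslt_mask and the shape comments are illustrative assumptions, not part of the patch:

use candle_core::{DType, Device, Result, Tensor};

// Normalize an optional attention mask to the rank-3 layout used by the
// flattened cuBLASLt path (illustrative helper mirroring the match in PATCH 2/2).
fn normalize_cublaslt_mask(mask: Option<&Tensor>) -> Result<Option<Tensor>> {
    match mask {
        // Presumably already [batch * heads, seq_q, seq_k]: pass through.
        Some(m) if m.rank() == 3 => Ok(Some(m.clone())),
        // Presumably [batch, heads, seq_q, seq_k]: fold batch and head dims.
        Some(m) if m.rank() == 4 => Ok(Some(m.flatten(0, 1)?)),
        // Presumably [seq_q, seq_k]: add a broadcastable leading dim.
        Some(m) if m.rank() == 2 => Ok(Some(m.unsqueeze(0)?)),
        Some(m) => candle_core::bail!(
            "cublaslt attn mask: rank must be 3, 4, or 2 (got {})",
            m.rank()
        ),
        None => Ok(None),
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Rank-2 mask [5, 5] becomes [1, 5, 5].
    let rank2 = Tensor::zeros((5, 5), DType::F32, &dev)?;
    println!("{:?}", normalize_cublaslt_mask(Some(&rank2))?.unwrap().dims());

    // Rank-4 mask [2, 8, 5, 5] becomes [16, 5, 5].
    let rank4 = Tensor::zeros((2, 8, 5, 5), DType::F32, &dev)?;
    println!("{:?}", normalize_cublaslt_mask(Some(&rank4))?.unwrap().dims());

    // Any other rank is rejected, matching the bail! in the patch.
    let rank1 = Tensor::zeros(5, DType::F32, &dev)?;
    assert!(normalize_cublaslt_mask(Some(&rank1)).is_err());

    Ok(())
}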