Use sdpa to benefit from cublaslt for vllama
EricLBuehler committed Nov 22, 2024
1 parent afc9d41 commit d9e7e7b
Showing 1 changed file with 6 additions and 2 deletions.
mistralrs-core/src/attention.rs — 6 additions & 2 deletions
@@ -92,15 +92,19 @@ fn naive_sdpa(
     head_dim: usize,
     sdpa_params: &SdpaParams,
 ) -> Result<Tensor> {
-    if let Some(mask) = mask {
+    if mask.is_some_and(|mask| mask.rank() == 2) {
         let mut att = MatMul.matmul(q, &k.t()?)?;
         if let Some(softcap) = sdpa_params.softcap {
             att = (att / softcap as f64)?;
             att = att.tanh()?;
             att = (att * softcap as f64)?;
         }
 
-        let att = candle_nn::ops::attn_softmax_last_dim(&att, mask, 1. / (head_dim as f32).sqrt())?;
+        let att = candle_nn::ops::attn_softmax_last_dim(
+            &att,
+            mask.unwrap(),
+            1. / (head_dim as f32).sqrt(),
+        )?;
         MatMul.matmul(&att, v)
     } else {
         let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;
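For context: this branch computes standard scaled-dot-product attention, softmax(QKᵀ · scale + mask) · V, with the scaling and the mask addition fused into attn_softmax_last_dim (a kernel from mistral.rs's candle fork) so the surrounding matmuls can dispatch to cuBLASLt. The gate now also requires the mask to be rank 2, presumably the (seq_q, seq_k) shape that kernel broadcasts. Below is a minimal unfused sketch of the same math using stock candle ops; the function name is hypothetical, and it assumes the fused call computes a last-dim softmax of att * scale + mask:

use candle_core::{Result, Tensor};

// Hypothetical unfused sketch of the rank-2 branch above:
// softmax(QK^T * scale + mask) V, spelled out with stock candle ops.
fn sdpa_rank2_sketch(
    q: &Tensor,           // (batch, heads, seq_q, head_dim)
    k: &Tensor,           // (batch, heads, seq_k, head_dim)
    v: &Tensor,           // (batch, heads, seq_k, head_dim)
    mask: &Tensor,        // rank 2: (seq_q, seq_k), broadcast over batch/heads
    head_dim: usize,
    softcap: Option<f32>, // optional tanh logit softcap, as in the diff
) -> Result<Tensor> {
    let mut att = q.matmul(&k.t()?)?; // attention scores: (batch, heads, seq_q, seq_k)
    if let Some(softcap) = softcap {
        // Tanh softcap clamps the logits to (-softcap, softcap), mirroring the diff.
        att = ((att / softcap as f64)?.tanh()? * softcap as f64)?;
    }
    // The three steps attn_softmax_last_dim fuses into one kernel:
    // scale, mask add, last-dim softmax.
    let scale = 1f64 / (head_dim as f64).sqrt();
    let att = (att * scale)?.broadcast_add(mask)?;
    let att = candle_nn::ops::softmax_last_dim(&att)?;
    att.matmul(v) // (batch, heads, seq_q, head_dim)
}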
