More efficient vision attnmask creation
EricLBuehler committed Nov 7, 2024
1 parent 6018fab commit f4dd913
Showing 1 changed file with 37 additions and 15 deletions.
mistralrs-core/src/vision_models/qwen2vl/vision.rs (37 additions, 15 deletions)
@@ -107,7 +107,12 @@ impl VisionAttention {
             head_dim: dim / num_heads,
         })
     }
-    fn forward(&self, xs: &Tensor, cu_seqlens: &[u32], rotary_pos_emb: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        attention_mask: Option<&Tensor>,
+        rotary_pos_emb: &Tensor,
+    ) -> Result<Tensor> {
         let seq_len = xs.dim(0)?;
         let (q, k, v) = {
             let qkv = self
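
For context on what the new attention_mask argument encodes: cu_seqlens holds cumulative patch counts per image, so two images of 4 and 5 patches give [0, 4, 9], and a patch may only attend to patches from its own image. A minimal plain-Rust sketch of that block-diagonal structure, standing in for the candle tensor built later in this diff (the function name is hypothetical):

fn block_diagonal_mask(cu_seqlens: &[u32], seq_len: usize) -> Vec<Vec<f32>> {
    // Start fully masked, then zero out each image's own block so the
    // additive mask permits attention only within an image.
    let mut mask = vec![vec![f32::MIN; seq_len]; seq_len];
    for w in cu_seqlens.windows(2) {
        let (a, b) = (w[0] as usize, w[1] as usize);
        for i in a..b {
            for j in a..b {
                mask[i][j] = 0.0;
            }
        }
    }
    mask
}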
@@ -126,23 +131,16 @@ impl VisionAttention {
             .squeeze(0)?
             .to_dtype(q.dtype())?;
 
-        let mut attention_mask = Tensor::full(f32::MIN, (1, seq_len, seq_len), q.device())?;
-        for i in 1..cu_seqlens.len() {
-            let a = cu_seqlens[i - 1] as usize;
-            let b = cu_seqlens[i] as usize;
-            attention_mask = attention_mask.slice_assign(
-                &[&.., &(a..b), &(a..b)],
-                &Tensor::zeros((1, b - a, b - a), DType::F32, q.device())?,
-            )?;
-        }
-
         let q = q.transpose(0, 1)?.contiguous()?;
         let k = k.transpose(0, 1)?.contiguous()?;
         let v = v.transpose(0, 1)?.contiguous()?;
 
         let att = {
             let att = (q.matmul(&k.transpose(1, 2)?)? / (self.head_dim as f64).sqrt())?;
-            let att = att.broadcast_add(&attention_mask.to_dtype(q.dtype())?)?;
+            let att = match attention_mask {
+                Some(m) => att.broadcast_add(m)?,
+                None => att,
+            };
             let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                 .to_dtype(q.dtype())?;
             att.matmul(&v)?
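
The None arm is what makes the mask optional: softmax(att + 0) == softmax(att), so a single image needs no mask at all, while with multiple images adding f32::MIN drives cross-image weights to zero after the softmax. A self-contained sketch of that arithmetic (plain Rust, no candle):

fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let scores = [1.0_f32, 2.0, 3.0];
    // The third patch belongs to a different image: masked with f32::MIN.
    let masked: Vec<f32> = scores
        .iter()
        .zip([0.0, 0.0, f32::MIN])
        .map(|(s, m)| s + m)
        .collect();
    println!("{:?}", softmax(&masked)); // ~[0.269, 0.731, 0.0]
}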
@@ -180,11 +178,16 @@ impl VisionBlock {
         })
     }
 
-    fn forward(&self, xs: &Tensor, cu_seqlens: &[u32], rotary_pos_emb: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        attention_mask: Option<&Tensor>,
+        rotary_pos_emb: &Tensor,
+    ) -> Result<Tensor> {
         let xs = (xs
             + self
                 .attn
-                .forward(&self.norm1.forward(xs)?, cu_seqlens, rotary_pos_emb)?)?;
+                .forward(&self.norm1.forward(xs)?, attention_mask, rotary_pos_emb)?)?;
         &xs + self.mlp.forward(&self.norm2.forward(&xs)?)?
     }
 }
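
A note on the Option<&Tensor> threading above (a design reading, not stated in the commit): the model can own the mask once and lend it to every block, so no per-block clone or rebuild is needed. A toy illustration with a plain struct standing in for Tensor:

struct Mask(Vec<f32>);

fn block_forward(mask: Option<&Mask>) {
    if let Some(m) = mask {
        let _ = m.0.len(); // apply the borrowed mask; nothing was cloned
    }
}

fn main() {
    let mask: Option<Mask> = Some(Mask(vec![0.0; 16]));
    for _ in 0..32 {
        block_forward(mask.as_ref()); // Option<Mask> -> Option<&Mask>
    }
}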
@@ -345,8 +348,27 @@ impl Qwen2VLVisionModel {
             .pad_with_zeros(0, 1, 0)?
             .to_vec1::<u32>()?;
 
+        let seq_len = xs.dim(0)?;
+        let attention_mask = match &cu_seqlens[..] {
+            &[0, len] if len == seq_len as u32 => None,
+            cu_seqlens => {
+                let mut attention_mask =
+                    Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
+                        .to_dtype(xs.dtype())?;
+                for i in 1..cu_seqlens.len() {
+                    let a = cu_seqlens[i - 1] as usize;
+                    let b = cu_seqlens[i] as usize;
+                    attention_mask = attention_mask.slice_assign(
+                        &[&.., &(a..b), &(a..b)],
+                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
+                    )?;
+                }
+                Some(attention_mask)
+            }
+        };
+
         for blk in &self.blocks {
-            xs = blk.forward(&xs, &cu_seqlens, &rotary_pos_emb)?;
+            xs = blk.forward(&xs, attention_mask.as_ref(), &rotary_pos_emb)?;
         }
 
         self.patch_merger.forward(&xs)
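
This hunk is the efficiency win named in the commit title: the mask was previously rebuilt inside every attention layer, and is now built at most once per forward pass, or skipped entirely when cu_seqlens is exactly [0, seq_len], i.e. a single image. A small sketch of that fast-path test (needs_mask is a hypothetical name):

fn needs_mask(cu_seqlens: &[u32], seq_len: usize) -> bool {
    !matches!(cu_seqlens, &[0, len] if len == seq_len as u32)
}

fn main() {
    assert!(!needs_mask(&[0, 9], 9)); // one 9-patch image: no mask needed
    assert!(needs_mask(&[0, 4, 9], 9)); // two images: block-diagonal mask
}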