Commit
Correctly do the text model rope
EricLBuehler committed Oct 29, 2024
1 parent 1432027 commit 0d89101
Showing 4 changed files with 154 additions and 48 deletions.
94 changes: 94 additions & 0 deletions mistralrs-core/src/layers.rs
@@ -25,6 +25,7 @@ use crate::{
cublaslt::CUBLASLT_HANDLE,
gguf::Content,
models::llama,
ops::SplitOp,
vision_models::mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
INHIBIT_GEMM_F16,
};
@@ -622,6 +623,99 @@ impl Llama3RotaryEmbedding {
}
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
cos: Tensor,
sin: Tensor,
mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
pub fn new(
base: f32,
head_dim: usize,
max_position_embeddings: usize,
device: &Device,
dtype: DType,
mrope_section: Vec<usize>,
) -> Result<Self> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
.collect();
let theta_len = theta.len();
let theta = Tensor::from_vec(theta, (1, 1, theta_len), device)?
.to_dtype(DType::F32)?
.repeat((3, 1, 1))?;
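// Outer product of positions and inverse frequencies via a batched matmul, giving angle tables of
// shape (3, max_position_embeddings, head_dim / 2): one per M-RoPE stream (temporal, height, width).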
let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)?
.to_dtype(DType::F32)?
.reshape((1, max_position_embeddings, 1))?
.repeat((3, 1, 1))?
.matmul(&theta)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
cos,
sin,
mrope_section,
})
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
pub fn forward(&self, positions: &[usize], q: &mut Tensor, k: &mut Tensor) -> Result<()> {
let mrope_scaling: Vec<_> =
[self.mrope_section.clone(), self.mrope_section.clone()].concat();
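// Split the last dim into the doubled mrope sections; for the i-th section take the
// (i % 3)-th position stream (temporal, height, width), then re-concatenate.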
let cos = Tensor::cat(
&self
.cos
.split(&mrope_scaling, D::Minus1)?
.into_iter()
.enumerate()
.map(|(i, m)| m.i(i % 3))
.collect::<Result<Vec<_>>>()?,
D::Minus1,
)?
.unsqueeze(1)?;
let sin = Tensor::cat(
&self
.sin
.split(&mrope_scaling, D::Minus1)?
.into_iter()
.enumerate()
.map(|(i, m)| m.i(i % 3))
.collect::<Result<Vec<_>>>()?,
D::Minus1,
)?
.unsqueeze(1)?;

let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
for (i, offset) in positions.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let q = q.i(i)?.unsqueeze(0)?.contiguous()?;
let k = k.i(i)?.unsqueeze(0)?.contiguous()?;
let q_embed = ((&q * &cos)? + (Self::rotate_half(&q)? * &sin))?;
let k_embed = ((&k * &cos)? + (Self::rotate_half(&k)? * &sin))?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}

*q = Tensor::cat(&q_embeds, 0)?;
*k = Tensor::cat(&k_embeds, 0)?;
Ok(())
}
}

/// Matrix multiplication, configurable to be via f16 (to use the faster GEMM kernels) optionally.
pub struct MatMul;

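As an aside, the section-wise selection that `Qwen2VLRotaryEmbedding::forward` applies to the cos/sin tables can be illustrated with a small standalone sketch. The function name, toy shapes, and section sizes below are my own illustrative choices (and the sketch assumes the table's last dimension equals the sum of the doubled sections), not code from this commit:

use candle_core::{Device, IndexOp, Result, Tensor, D};

// Pick, for each section of the last dim, one of the three M-RoPE position
// streams (0 = temporal, 1 = height, 2 = width) and re-concatenate them.
fn select_mrope_sections(table: &Tensor, sections: &[usize]) -> Result<Tensor> {
    let mut start = 0;
    let mut chunks = Vec::new();
    for (i, len) in sections.iter().enumerate() {
        // table: (3, seq_len, sum(sections)); narrow the last dim, then index the stream.
        chunks.push(table.narrow(D::Minus1, start, *len)?.i(i % 3)?);
        start += *len;
    }
    Tensor::cat(&chunks, D::Minus1)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy table: 3 streams x 2 positions x 12 features, each filled with its stream index.
    let table = Tensor::new(
        &[[[0f32; 12]; 2], [[1f32; 12]; 2], [[2f32; 12]; 2]],
        &dev,
    )?;
    // Doubled sections, e.g. a hypothetical mrope_section = [1, 2, 3] repeated twice.
    let sections = [1usize, 2, 3, 1, 2, 3];
    let out = select_mrope_sections(&table, &sections)?;
    // Each output column carries the stream it was taken from:
    // [0, 1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2]
    println!("{:?}", out.i(0)?.to_vec1::<f32>()?);
    Ok(())
}
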
37 changes: 35 additions & 2 deletions mistralrs-core/src/ops.rs
@@ -1,6 +1,6 @@
use candle_core::{
backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape,
Tensor, WithDType, D,
backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
Result, Shape, Tensor, WithDType, D,
};

use std::{
@@ -613,6 +613,23 @@ impl RepeatInterleaveOp for Tensor {
}
}

pub trait SplitOp {
fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
let dim = dim.to_index(self.shape(), "split")?;
let mut split_res = Vec::new();
let mut index = 0;
for split in splits {
split_res.push(self.narrow(dim, index, *split)?);
index += *split;
}
Ok(split_res)
}
}

mod tests {
#[test]
fn test_topk() {
@@ -855,4 +872,20 @@ mod tests {

Ok(())
}

#[test]
fn test_repeat_interleave_flat() -> candle_core::Result<()> {
use crate::ops::RepeatInterleaveOp;
use candle_core::{Device, Tensor};

let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
assert_eq!(
repeat_interleaved.to_vec1::<f32>()?,
vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
);

Ok(())
}
}
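A minimal test-style sketch of the cumulative-offset behaviour of `SplitOp::split` (written in the same style as the tests above, but not part of the commit; it assumes the `index += *split` accumulation shown in the impl):

#[test]
fn split_uses_cumulative_offsets() -> candle_core::Result<()> {
    use crate::ops::SplitOp;
    use candle_core::{Device, Tensor};

    // [0, 1, 2, 3, 4, 5] split into chunks of sizes 2, 1 and 3 along dim 0.
    let input = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    let chunks = input.split(&[2, 1, 3], 0)?;

    assert_eq!(chunks[0].to_vec1::<f32>()?, vec![0., 1.]);
    assert_eq!(chunks[1].to_vec1::<f32>()?, vec![2.]);
    assert_eq!(chunks[2].to_vec1::<f32>()?, vec![3., 4., 5.]);
    Ok(())
}
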
8 changes: 2 additions & 6 deletions mistralrs-core/src/vision_models/qwen2/mod.rs
@@ -1,14 +1,10 @@
use candle_core::Result;
use candle_nn::{Embedding, VarBuilder};
use candle_nn::VarBuilder;
use config::Config;
use text::Qwen2VLTextModel;
use vision::Qwen2VLVisionModel;

use crate::{
layers::{RmsNorm, RotaryEmbedding},
paged_attention::AttentionImplementation,
pipeline::NormalLoadingMetadata,
};
use crate::{paged_attention::AttentionImplementation, pipeline::NormalLoadingMetadata};

mod config;
mod text;
63 changes: 23 additions & 40 deletions mistralrs-core/src/vision_models/qwen2/text.rs
@@ -6,7 +6,7 @@ use candle_nn::{Activation, Embedding, Linear, Module, VarBuilder};
use crate::{
attention::SdpaParams,
dummy_paged_attention::ModelConfigMetadata,
layers::{CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
layers::{CausalMasker, Qwen2VLRotaryEmbedding, RmsNorm, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, PagedAttention},
pipeline::{
@@ -56,14 +56,14 @@ struct Attention {
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
}

impl Attention {
fn new(

(GitHub Actions annotations on line 65 — Clippy failure and Check/Docs/Test Suite warnings: associated items `new` and `forward` are never used.)
rotary_emb: Arc<RotaryEmbedding>,
rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
paged_attn: Option<PagedAttention>,
@@ -102,7 +102,6 @@ impl Attention {
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
@@ -113,37 +112,25 @@
let k = self.k_proj.forward(&xs)?;
let v = self.v_proj.forward(&xs)?;

let mut q = q.reshape((b_sz * q_len, self.num_heads, self.head_dim))?;
let mut k = k.reshape((b_sz * q_len, self.num_kv_heads, self.head_dim))?;
let v = if q_len != 1 {
v.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
let (mut q, mut k, v) = if q_len != 1 {
let q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
(q, k, v)
} else {
// Optimization for seqlen = 1, avoid transpose and just modify reshape dims
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
(q, k, v)
};

self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

if q.rank() == 3 && q_len != 1 {
q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
k = k
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
} else if q.rank() == 3 {
// Optimization for seqlen = 1, avoid transpose and just modify reshape dims
q = q
.reshape((b_sz, self.num_heads, q_len, self.head_dim))?
.contiguous()?;
k = k
.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
.contiguous()?;
}
self.rotary_emb.forward(seqlen_offsets, &mut q, &mut k)?;

let mut attn_output = match &self.paged_attn {
Some(paged_attn) => {
@@ -191,7 +178,7 @@ pub struct DecoderLayer {

impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
paged_attn: Option<PagedAttention>,
@@ -219,7 +206,6 @@ impl DecoderLayer {
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
@@ -230,7 +216,6 @@
&xs,
attention_mask,
seqlen_offsets,
start_offsets_kernel,
kv_cache,
metadata,
flash_params,
@@ -258,7 +243,7 @@ impl Qwen2VLTextModel {
pub fn new(
cfg: &Config,
vb: VarBuilder,
is_gptx: bool,
_is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
@@ -274,13 +259,13 @@
let device = &normal_loading_metadata.real_device;
ropes.insert(
device.location(),
Arc::new(RotaryEmbedding::new(
Arc::new(Qwen2VLRotaryEmbedding::new(
cfg.rope_theta as f32,
head_dim,
cfg.max_position_embeddings,
device,
is_gptx,
vb_m.dtype(),
cfg.rope_scaling.mrope_section.clone(),
)?),
);
}
@@ -338,7 +323,6 @@ impl Qwen2VLTextModel {
&self,
input_ids: &Tensor,
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
context_lens: Vec<(usize, usize)>,
mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
@@ -363,7 +347,6 @@
.map(|m| m.to_device(xs.device()).unwrap())
.as_ref(),
seqlen_offsets,
start_offsets_kernel.clone(),
&mut cache[i],
metadata
.as_mut()
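The "Optimization for seqlen = 1" comment in the attention forward above relies on reshape and transpose producing the same layout when the sequence length is 1; a small standalone sketch with toy shapes of my own choosing (not part of the commit) checks that equivalence:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_sz, q_len, num_heads, head_dim) = (2usize, 1usize, 4usize, 8usize);
    // A stand-in for a projection output of shape (b, q_len, num_heads * head_dim).
    let x = Tensor::rand(0f32, 1f32, (b_sz, q_len, num_heads * head_dim), &dev)?;

    // General path: (b, q_len, heads, head_dim), then transpose to (b, heads, q_len, head_dim).
    let general = x
        .reshape((b_sz, q_len, num_heads, head_dim))?
        .transpose(1, 2)?
        .contiguous()?;
    // q_len == 1 shortcut: reshape straight to (b, heads, 1, head_dim), no transpose needed.
    let shortcut = x.reshape((b_sz, num_heads, q_len, head_dim))?;

    // The two tensors are identical element-for-element when q_len == 1.
    let diff = (general - shortcut)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert_eq!(diff, 0.0);
    Ok(())
}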
