diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index f6c6fefd2a..92c3864bd4 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -23,8 +23,8 @@ pub use normal_loaders::{ }; pub use vision_loaders::{ - Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, VLlamaLoader, VisionLoaderType, - VisionModel, VisionModelLoader, + Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, Qwen2VLLoader, VLlamaLoader, + VisionLoaderType, VisionModel, VisionModelLoader, }; pub use diffusion_loaders::{ diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs index a41ca568bd..d20cd32946 100644 --- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs @@ -30,6 +30,7 @@ use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3}; use crate::vision_models::phi3_inputs_processor::Phi3Processor; use crate::vision_models::preprocessor_config::PreProcessorConfig; use crate::vision_models::processor_config::ProcessorConfig; +use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel}; pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin { // pixel_values and pixel_attention_mask only specified for prompt seqs @@ -89,6 +90,8 @@ pub enum VisionLoaderType { LLaVA, #[serde(rename = "vllama")] VLlama, + #[serde(rename = "qwen2vl")] + Qwen2VL, } impl FromStr for VisionLoaderType { @@ -100,7 +103,8 @@ impl FromStr for VisionLoaderType { "llava_next" => Ok(Self::LLaVANext), "llava" => Ok(Self::LLaVA), "vllama" => Ok(Self::VLlama), - a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`.")), + "qwen2vl" => Ok(Self::Qwen2VL), + a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`.")), } } } @@ -448,3 +452,59 @@ impl IsqModelLoader for VLlamaLoader { ]) } } + +// ======================== Qwen2VL Loader + +/// [`VisionLoader`] for an Qwen2-VL model. +/// +/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html +pub struct Qwen2VLLoader; + +impl VisionModelLoader for Qwen2VLLoader { + fn load( + &self, + config: &str, + _use_flash_attn: bool, + vb: VarBuilder, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: AttentionImplementation, + ) -> Result> { + let config: Qwen2VLConfig = serde_json::from_str(config)?; + Ok(Box::new(Qwen2VLModel::new( + &config, + vb, + self.is_gptx(), + normal_loading_metadata, + attention_mechanism, + )?)) + } + fn is_gptx(&self) -> bool { + true + } + fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result> { + let config: Qwen2VLConfig = serde_json::from_str(config)?; + Ok(Box::new(config)) + } + fn get_processor( + &self, + _model_config: &str, + _processor_config: Option, + _preprocessor_config: PreProcessorConfig, + ) -> Arc { + Arc::new(MLlamaProcessor::new()) + } + fn get_total_device_mapping_num_layers(&self, config: &str) -> Result { + let config: Qwen2VLConfig = serde_json::from_str(config)?; + // We only apply device mapping to text model + Ok(config.num_hidden_layers) + } + fn supports_paged_attention(&self) -> bool { + false + } +} + +impl IsqModelLoader for Qwen2VLLoader { + fn isq_layer_regexes(&self, _config: &str) -> Result> { + todo!() + } +} diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index ad6927bec2..b373c4ae11 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -34,8 +34,8 @@ pub use loaders::{ Gemma2Loader, GemmaLoader, Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, - Phi3_5MoELoader, PrettyName, QuantizationKind, Qwen2Loader, Starcoder2Loader, TokenSource, - VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader, + Phi3_5MoELoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Starcoder2Loader, + TokenSource, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader, }; use mistralrs_quant::IsqType; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index e8b8083920..02b43e9fe6 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -4,7 +4,7 @@ use super::{ get_model_paths, get_xlora_paths, AdapterActivationMixin, AnyMoePipelineMixin, Cache, CacheManager, CacheManagerMixin, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, Processor, - TokenSource, VLlamaLoader, VisionModel, VisionModelLoader, XLoraPaths, + Qwen2VLLoader, TokenSource, VLlamaLoader, VisionModel, VisionModelLoader, XLoraPaths, }; use super::{Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, VisionLoaderType}; use crate::aici::bintokens::build_tok_trie; @@ -118,6 +118,7 @@ impl VisionLoaderBuilder { VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader), VisionLoaderType::LLaVA => Box::new(LLaVALoader), VisionLoaderType::VLlama => Box::new(VLlamaLoader), + VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader), }; Box::new(VisionLoader { inner: loader, diff --git a/mistralrs-core/src/vision_models/mod.rs b/mistralrs-core/src/vision_models/mod.rs index 734368a09a..7b93ae7a01 100644 --- a/mistralrs-core/src/vision_models/mod.rs +++ b/mistralrs-core/src/vision_models/mod.rs @@ -12,7 +12,7 @@ pub(crate) mod phi3; pub(crate) mod phi3_inputs_processor; pub(crate) mod preprocessor_config; pub(crate) mod processor_config; -pub(crate) mod qwen2; +pub(crate) mod qwen2vl; pub(crate) use llava::llava15; pub(crate) use llava::llava_inputs_processor; pub(crate) use llava::llava_next; diff --git a/mistralrs-core/src/vision_models/qwen2/mod.rs b/mistralrs-core/src/vision_models/qwen2/mod.rs deleted file mode 100644 index 29ebd79919..0000000000 --- a/mistralrs-core/src/vision_models/qwen2/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -use candle_core::Result; -use candle_nn::VarBuilder; -use config::Config; -use text::Qwen2VLTextModel; -use vision::Qwen2VLVisionModel; - -use crate::{paged_attention::AttentionImplementation, pipeline::NormalLoadingMetadata}; - -mod config; -mod text; -mod vision; - -pub struct Qwen2VLModel { - model: Qwen2VLTextModel, - vision: Qwen2VLVisionModel, -} - -impl Qwen2VLModel { - fn new( - cfg: &Config, - vb: VarBuilder, - is_gptx: bool, - normal_loading_metadata: NormalLoadingMetadata, - attention_mechanism: AttentionImplementation, - ) -> Result { - let model = Qwen2VLTextModel::new( - cfg, - vb.clone(), - is_gptx, - normal_loading_metadata, - attention_mechanism, - )?; - let vision = Qwen2VLVisionModel::new(&cfg.vision_config, vb.pp("vision"))?; - Ok(Self { model, vision }) - } -} diff --git a/mistralrs-core/src/vision_models/qwen2/config.rs b/mistralrs-core/src/vision_models/qwen2vl/config.rs similarity index 95% rename from mistralrs-core/src/vision_models/qwen2/config.rs rename to mistralrs-core/src/vision_models/qwen2vl/config.rs index 0066570308..8af1f878e3 100644 --- a/mistralrs-core/src/vision_models/qwen2/config.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/config.rs @@ -39,4 +39,6 @@ pub struct Config { pub max_window_layers: usize, pub vision_config: VisionConfig, pub rope_scaling: MRopeScaling, + pub image_token_id: usize, + pub video_token_id: usize, } diff --git a/mistralrs-core/src/vision_models/qwen2vl/mod.rs b/mistralrs-core/src/vision_models/qwen2vl/mod.rs new file mode 100644 index 0000000000..7d2ed2ce1e --- /dev/null +++ b/mistralrs-core/src/vision_models/qwen2vl/mod.rs @@ -0,0 +1,202 @@ +use std::{any::Any, sync::Arc}; + +use candle_core::{Context, Device, Result, Tensor, D}; +use candle_nn::VarBuilder; +use mistralrs_quant::QuantMethod; +use text::Qwen2VLTextModel; +use vision::Qwen2VLVisionModel; + +use crate::{ + amoe::AnyMoeBaseModelMixin, + device_map::DeviceMapper, + dummy_paged_attention::ModelConfigMetadata, + layers::CausalMasker, + layers_masker::PastKvLenCache, + paged_attention::AttentionImplementation, + pipeline::{ + text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, + Cache, IsqModel, NormalLoadingMetadata, VisionModel, + }, +}; + +mod config; +mod text; +mod vision; + +pub(crate) use config::Config; + +pub struct Qwen2VLModel { + text: Qwen2VLTextModel, + vision: Qwen2VLVisionModel, + image_token_id: usize, + video_token_id: usize, +} + +impl Qwen2VLModel { + pub fn new( + cfg: &Config, + vb: VarBuilder, + is_gptx: bool, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: AttentionImplementation, + ) -> Result { + if cfg.use_sliding_window { + // TODO! + candle_core::bail!("Sliding window is unsupported for now!"); + } + let text = Qwen2VLTextModel::new( + cfg, + vb.clone(), + is_gptx, + normal_loading_metadata, + attention_mechanism, + )?; + let vision = Qwen2VLVisionModel::new(&cfg.vision_config, vb.pp("vision"))?; + Ok(Self { + text, + vision, + image_token_id: cfg.image_token_id, + video_token_id: cfg.video_token_id, + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + pixel_values: Option, + pixel_values_videos: Option, + image_grid_thw: Option, + video_grid_thw: Option, + seqlen_offsets: &[usize], + context_lens: Vec<(usize, usize)>, + metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + let (input_embeds, attention_mask) = if pixel_values.is_some() + || pixel_values_videos.is_some() + { + let mut xs = self.text.embed_tokens(input_ids)?; + + let cache = self.text.cache.lock(); + let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias( + input_ids, + metadata + .as_ref() + .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) + .unwrap_or(&*cache as &dyn PastKvLenCache), + self.text.cfg.sliding_window, + xs.dtype(), + self.text.cfg.num_attn_heads, + )?; + + if let Some(pixel_values) = pixel_values { + let image_embeds = self.vision.forward( + &pixel_values, + &image_grid_thw.context("pixel_values require image_grid_thw")?, + )?; + let image_mask = input_ids + .eq(self.image_token_id as f64)? + .unsqueeze(D::Minus1)? + .broadcast_as(xs.shape())?; + xs = image_mask.where_cond(&image_embeds, &xs)?; + } + + if let Some(pixel_values_videos) = pixel_values_videos { + let video_embeds = self.vision.forward( + &pixel_values_videos, + &video_grid_thw.context("pixel_values_videos require video_grid_thw")?, + )?; + let video_mask = input_ids + .eq(self.video_token_id as f64)? + .unsqueeze(D::Minus1)? + .broadcast_as(xs.shape())?; + xs = video_mask.where_cond(&video_embeds, &xs)?; + } + + (xs, attention_mask) + } else { + let xs = self.text.embed_tokens(input_ids)?; + (xs, None) + }; + + self.text.forward_embeds( + input_embeds, + attention_mask.as_ref(), + seqlen_offsets, + context_lens, + metadata, + flash_params, + ) + } +} + +pub(crate) struct Qwen2VLVisionSpecificArgs { + image_grid_thw: Option, + video_grid_thw: Option, + pixel_values_video: Option, +} + +impl VisionModel for Qwen2VLModel { + fn forward( + &self, + input_ids: &Tensor, + pixel_values: Option, + seqlen_offsets: &[usize], + _start_offsets_kernel: Tensor, + context_lens: Vec<(usize, usize)>, + _position_ids: Vec, + model_specific_args: Box, + metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + let Qwen2VLVisionSpecificArgs { + image_grid_thw, + video_grid_thw, + pixel_values_video, + } = *model_specific_args + .downcast() + .expect("Cannot downcast into `Qwen2VLVisionSpecificArgs`"); + self.forward( + input_ids, + pixel_values, + pixel_values_video, + image_grid_thw, + video_grid_thw, + seqlen_offsets, + context_lens, + metadata, + flash_params, + ) + } + fn cache(&self) -> &Cache { + &self.text.cache + } + fn device(&self) -> &Device { + &self.text.device + } + fn max_seq_len(&self) -> usize { + self.text.max_seq_len + } + fn has_conv2d(&self) -> bool { + true + } + fn config(&self) -> &ModelConfigMetadata { + &self.text.cfg + } +} + +impl IsqModel for Qwen2VLModel { + fn get_layers( + &mut self, + ) -> ( + Vec<(&mut Arc, Option)>, + &dyn DeviceMapper, + ) { + todo!() + } + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + todo!() + } +} + +impl AnyMoeBaseModelMixin for Qwen2VLModel {} diff --git a/mistralrs-core/src/vision_models/qwen2/text.rs b/mistralrs-core/src/vision_models/qwen2vl/text.rs similarity index 93% rename from mistralrs-core/src/vision_models/qwen2/text.rs rename to mistralrs-core/src/vision_models/qwen2vl/text.rs index 95e20ff968..feebc08875 100644 --- a/mistralrs-core/src/vision_models/qwen2/text.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/text.rs @@ -1,13 +1,12 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{Result, Tensor}; +use candle_core::{Device, Result, Tensor}; use candle_nn::{Activation, Embedding, Linear, Module, VarBuilder}; use crate::{ attention::SdpaParams, dummy_paged_attention::ModelConfigMetadata, - layers::{CausalMasker, Qwen2VLRotaryEmbedding, RmsNorm, Sdpa}, - layers_masker::PastKvLenCache, + layers::{Qwen2VLRotaryEmbedding, RmsNorm, Sdpa}, paged_attention::{AttentionImplementation, PagedAttention}, pipeline::{ extract_logits, @@ -234,9 +233,10 @@ pub struct Qwen2VLTextModel { norm: RmsNorm, layers: Vec, lm_head: Linear, - cache: Cache, - max_seq_len: usize, - cfg: ModelConfigMetadata, + pub(super) cache: Cache, + pub(super) cfg: ModelConfigMetadata, + pub(super) device: Device, + pub(super) max_seq_len: usize, } impl Qwen2VLTextModel { @@ -316,29 +316,24 @@ impl Qwen2VLTextModel { sliding_window: cfg.sliding_window, head_dim: None, }, + device: vb.device().clone(), }) } - pub fn forward( + pub fn embed_tokens(&self, input_ids: &Tensor) -> Result { + self.embed_tokens.forward(input_ids) + } + + pub fn forward_embeds( &self, - input_ids: &Tensor, + mut xs: Tensor, + attention_mask: Option<&Tensor>, seqlen_offsets: &[usize], context_lens: Vec<(usize, usize)>, mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, ) -> Result { - let mut xs = self.embed_tokens.forward(input_ids)?; let mut cache = self.cache.lock(); - let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias( - input_ids, - metadata - .as_ref() - .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) - .unwrap_or(&*cache as &dyn PastKvLenCache), - self.cfg.sliding_window, - xs.dtype(), - self.layers[0].self_attn.num_heads, - )?; for (i, layer) in self.layers.iter().enumerate() { xs = layer.forward( &xs, diff --git a/mistralrs-core/src/vision_models/qwen2/vision.rs b/mistralrs-core/src/vision_models/qwen2vl/vision.rs similarity index 95% rename from mistralrs-core/src/vision_models/qwen2/vision.rs rename to mistralrs-core/src/vision_models/qwen2vl/vision.rs index d48d595126..7ecca1ab20 100644 --- a/mistralrs-core/src/vision_models/qwen2/vision.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/vision.rs @@ -1,5 +1,5 @@ use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, Activation, LayerNorm, Linear, Module, Sequential, VarBuilder}; +use candle_nn::{layer_norm, Activation, LayerNorm, Linear, Module, VarBuilder}; use crate::{ layers::{Conv3dConfig, Conv3dNoBias}, @@ -9,7 +9,6 @@ use crate::{ use super::config::VisionConfig; struct PatchEmbed { - cfg: VisionConfig, proj: Conv3dNoBias, in_channels: usize, patch_size: usize, @@ -24,7 +23,6 @@ impl PatchEmbed { candle_core::bail!("Only support temporal patch size of 2"); } Ok(Self { - cfg: cfg.clone(), proj: Conv3dNoBias::new( cfg.in_channels, cfg.embed_dim, @@ -194,7 +192,8 @@ impl VisionBlock { struct PatchMerger { ln_q: LayerNorm, - mlp: Sequential, + mlp0: Linear, + mlp2: Linear, hidden_size: usize, } @@ -206,17 +205,12 @@ impl PatchMerger { vb: VarBuilder, ) -> Result { let hidden_size = context_dim * spatial_merge_size.pow(2); - let mut mlp = candle_nn::seq(); - mlp = mlp.add(candle_nn::linear_no_bias( - hidden_size, - hidden_size, - vb.pp("mlp.0"), - )?); - mlp = mlp.add(candle_nn::Activation::Gelu); - mlp = mlp.add(candle_nn::linear_no_bias(hidden_size, dim, vb.pp("mlp.2"))?); + let mlp0 = candle_nn::linear_no_bias(hidden_size, hidden_size, vb.pp("mlp.0"))?; + let mlp2 = candle_nn::linear_no_bias(hidden_size, dim, vb.pp("mlp.2"))?; Ok(Self { ln_q: layer_norm(context_dim, 1e-6, vb.pp("ln_q"))?, - mlp, + mlp0, + mlp2, hidden_size, }) } @@ -224,7 +218,9 @@ impl PatchMerger { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.ln_q)? .reshape(((), self.hidden_size))? - .apply(&self.mlp) + .apply(&self.mlp0)? + .gelu()? + .apply(&self.mlp2) } } @@ -329,7 +325,7 @@ impl Qwen2VLVisionModel { .flatten_from(1) } - fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result { + pub fn forward(&self, xs: &Tensor, grid_thw: &Tensor) -> Result { let mut xs = self.patch_embed.forward(xs)?; let rotary_pos_emb = self.rot_pos_emb(grid_thw)?;