Integrate with the loaders
EricLBuehler committed Oct 29, 2024
1 parent 0d89101 commit 2dffb9e
Showing 10 changed files with 297 additions and 77 deletions.
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/loaders/mod.rs
@@ -23,8 +23,8 @@ pub use normal_loaders::{
 };
 
 pub use vision_loaders::{
-    Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, VLlamaLoader, VisionLoaderType,
-    VisionModel, VisionModelLoader,
+    Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, Qwen2VLLoader, VLlamaLoader,
+    VisionLoaderType, VisionModel, VisionModelLoader,
 };
 
 pub use diffusion_loaders::{
62 changes: 61 additions & 1 deletion mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -30,6 +30,7 @@ use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3};
 use crate::vision_models::phi3_inputs_processor::Phi3Processor;
 use crate::vision_models::preprocessor_config::PreProcessorConfig;
 use crate::vision_models::processor_config::ProcessorConfig;
+use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel};
 
 pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
     // pixel_values and pixel_attention_mask only specified for prompt seqs
@@ -89,6 +90,8 @@ pub enum VisionLoaderType {
     LLaVA,
     #[serde(rename = "vllama")]
     VLlama,
+    #[serde(rename = "qwen2vl")]
+    Qwen2VL,
 }

@@ -100,7 +103,8 @@ impl FromStr for VisionLoaderType {
             "llava_next" => Ok(Self::LLaVANext),
             "llava" => Ok(Self::LLaVA),
             "vllama" => Ok(Self::VLlama),
-            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`.")),
+            "qwen2vl" => Ok(Self::Qwen2VL),
+            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`.")),
         }
     }
 }
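
Note: with the new arm, the architecture string `qwen2vl` now parses like the existing ones. A minimal sketch of exercising it (not part of the commit; assumes `VisionLoaderType` is re-exported at the crate root as in `pipeline/mod.rs` below):

use std::str::FromStr;

use mistralrs_core::VisionLoaderType;

fn main() {
    // "qwen2vl" hits the new arm instead of the error branch.
    let ty = VisionLoaderType::from_str("qwen2vl").unwrap();
    assert!(matches!(ty, VisionLoaderType::Qwen2VL));
    // Unknown names still return the error listing every known architecture.
    assert!(VisionLoaderType::from_str("qwen3vl").is_err());
}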
@@ -448,3 +452,59 @@ impl IsqModelLoader for VLlamaLoader {
         ])
     }
 }
+
+// ======================== Qwen2VL Loader
+
+/// [`VisionLoader`] for a Qwen2-VL model.
+///
+/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
+pub struct Qwen2VLLoader;
+
+impl VisionModelLoader for Qwen2VLLoader {
+    fn load(
+        &self,
+        config: &str,
+        _use_flash_attn: bool,
+        vb: VarBuilder,
+        normal_loading_metadata: NormalLoadingMetadata,
+        attention_mechanism: AttentionImplementation,
+    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
+        let config: Qwen2VLConfig = serde_json::from_str(config)?;
+        Ok(Box::new(Qwen2VLModel::new(
+            &config,
+            vb,
+            self.is_gptx(),
+            normal_loading_metadata,
+            attention_mechanism,
+        )?))
+    }
+    fn is_gptx(&self) -> bool {
+        true
+    }
+    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
+        let config: Qwen2VLConfig = serde_json::from_str(config)?;
+        Ok(Box::new(config))
+    }
+    fn get_processor(
+        &self,
+        _model_config: &str,
+        _processor_config: Option<ProcessorConfig>,
+        _preprocessor_config: PreProcessorConfig,
+    ) -> Arc<dyn Processor + Send + Sync> {
+        Arc::new(MLlamaProcessor::new())
+    }
+    fn get_total_device_mapping_num_layers(&self, config: &str) -> Result<usize> {
+        let config: Qwen2VLConfig = serde_json::from_str(config)?;
+        // We only apply device mapping to the text model
+        Ok(config.num_hidden_layers)
+    }
+    fn supports_paged_attention(&self) -> bool {
+        false
+    }
+}
+
+impl IsqModelLoader for Qwen2VLLoader {
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+        todo!()
+    }
+}
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/mod.rs
@@ -34,8 +34,8 @@ pub use loaders::{
     Gemma2Loader, GemmaLoader, Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
     LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType,
     NormalLoadingMetadata, NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader,
-    Phi3_5MoELoader, PrettyName, QuantizationKind, Qwen2Loader, Starcoder2Loader, TokenSource,
-    VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
+    Phi3_5MoELoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Starcoder2Loader,
+    TokenSource, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
 };
 use mistralrs_quant::IsqType;
 pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/vision.rs
@@ -4,7 +4,7 @@ use super::{
     get_model_paths, get_xlora_paths, AdapterActivationMixin, AnyMoePipelineMixin, Cache,
     CacheManager, CacheManagerMixin, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin,
     Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, Processor,
-    TokenSource, VLlamaLoader, VisionModel, VisionModelLoader, XLoraPaths,
+    Qwen2VLLoader, TokenSource, VLlamaLoader, VisionModel, VisionModelLoader, XLoraPaths,
 };
 use super::{Idefics2Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, VisionLoaderType};
 use crate::aici::bintokens::build_tok_trie;
@@ -118,6 +118,7 @@ impl VisionLoaderBuilder {
             VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
             VisionLoaderType::LLaVA => Box::new(LLaVALoader),
             VisionLoaderType::VLlama => Box::new(VLlamaLoader),
+            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
         };
         Box::new(VisionLoader {
             inner: loader,
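
Note: this match is the single point where a parsed `VisionLoaderType` selects a concrete loader. A hedged sketch of that dispatch in isolation (not part of the commit; assumes the crate-root re-exports of `Qwen2VLLoader`, `VisionLoaderType`, and `VisionModelLoader` shown in `pipeline/mod.rs` above):

use mistralrs_core::{Qwen2VLLoader, VisionLoaderType, VisionModelLoader};

// Mirrors the match inside `VisionLoaderBuilder`: the parsed loader type picks
// the boxed loader that later deserializes the config and builds the model.
fn loader_for(ty: VisionLoaderType) -> Box<dyn VisionModelLoader> {
    match ty {
        VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
        _ => unimplemented!("remaining variants dispatch as in the match above"),
    }
}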
2 changes: 1 addition & 1 deletion mistralrs-core/src/vision_models/mod.rs
@@ -12,7 +12,7 @@ pub(crate) mod phi3;
 pub(crate) mod phi3_inputs_processor;
 pub(crate) mod preprocessor_config;
 pub(crate) mod processor_config;
-pub(crate) mod qwen2;
+pub(crate) mod qwen2vl;
 pub(crate) use llava::llava15;
 pub(crate) use llava::llava_inputs_processor;
 pub(crate) use llava::llava_next;
36 changes: 0 additions & 36 deletions mistralrs-core/src/vision_models/qwen2/mod.rs

This file was deleted.

mistralrs-core/src/vision_models/qwen2vl/config.rs
@@ -39,4 +39,6 @@ pub struct Config {
     pub max_window_layers: usize,
     pub vision_config: VisionConfig,
     pub rope_scaling: MRopeScaling,
+    pub image_token_id: usize,
+    pub video_token_id: usize,
 }

[CI annotations on line 39: the Clippy job fails and the Check, Docs, and Test Suite jobs (macOS, ubuntu, windows) warn that field `max_window_layers` is never read.]
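
Note: `Qwen2VLLoader::load` and `get_config_repr` pull these fields out of the model's `config.json` with `serde_json::from_str`. A self-contained sketch of how the two new token-id fields deserialize (not part of the commit; the struct and numeric values are illustrative stand-ins, not the crate's `Config`):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct TokenIds {
    image_token_id: usize,
    video_token_id: usize,
}

fn main() -> serde_json::Result<()> {
    // Unknown fields in a fuller config.json would be ignored by default.
    let excerpt = r#"{ "image_token_id": 151655, "video_token_id": 151656 }"#;
    let ids: TokenIds = serde_json::from_str(excerpt)?;
    assert_eq!(ids.video_token_id, 151656);
    println!("{ids:?}");
    Ok(())
}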
202 changes: 202 additions & 0 deletions mistralrs-core/src/vision_models/qwen2vl/mod.rs
@@ -0,0 +1,202 @@
use std::{any::Any, sync::Arc};

use candle_core::{Context, Device, Result, Tensor, D};
use candle_nn::VarBuilder;
use mistralrs_quant::QuantMethod;
use text::Qwen2VLTextModel;
use vision::Qwen2VLVisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    dummy_paged_attention::ModelConfigMetadata,
    layers::CausalMasker,
    layers_masker::PastKvLenCache,
    paged_attention::AttentionImplementation,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, IsqModel, NormalLoadingMetadata, VisionModel,
    },
};

mod config;
mod text;
mod vision;

pub(crate) use config::Config;

pub struct Qwen2VLModel {
    text: Qwen2VLTextModel,
    vision: Qwen2VLVisionModel,
    image_token_id: usize,
    video_token_id: usize,
}

impl Qwen2VLModel {
    pub fn new(
        cfg: &Config,
        vb: VarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            // TODO!
            candle_core::bail!("Sliding window is unsupported for now!");
        }
        let text = Qwen2VLTextModel::new(
            cfg,
            vb.clone(),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        let vision = Qwen2VLVisionModel::new(&cfg.vision_config, vb.pp("vision"))?;
        Ok(Self {
            text,
            vision,
            image_token_id: cfg.image_token_id,
            video_token_id: cfg.video_token_id,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        pixel_values_videos: Option<Tensor>,
        image_grid_thw: Option<Tensor>,
        video_grid_thw: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (input_embeds, attention_mask) = if pixel_values.is_some()
            || pixel_values_videos.is_some()
        {
            let mut xs = self.text.embed_tokens(input_ids)?;

            let cache = self.text.cache.lock();
            let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
                input_ids,
                metadata
                    .as_ref()
                    .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                    .unwrap_or(&*cache as &dyn PastKvLenCache),
                self.text.cfg.sliding_window,
                xs.dtype(),
                self.text.cfg.num_attn_heads,
            )?;

            if let Some(pixel_values) = pixel_values {
                let image_embeds = self.vision.forward(
                    &pixel_values,
                    &image_grid_thw.context("pixel_values require image_grid_thw")?,
                )?;
                let image_mask = input_ids
                    .eq(self.image_token_id as f64)?
                    .unsqueeze(D::Minus1)?
                    .broadcast_as(xs.shape())?;
                xs = image_mask.where_cond(&image_embeds, &xs)?;
            }

            if let Some(pixel_values_videos) = pixel_values_videos {
                let video_embeds = self.vision.forward(
                    &pixel_values_videos,
                    &video_grid_thw.context("pixel_values_videos require video_grid_thw")?,
                )?;
                let video_mask = input_ids
                    .eq(self.video_token_id as f64)?
                    .unsqueeze(D::Minus1)?
                    .broadcast_as(xs.shape())?;
                xs = video_mask.where_cond(&video_embeds, &xs)?;
            }

            (xs, attention_mask)
        } else {
            let xs = self.text.embed_tokens(input_ids)?;
            (xs, None)
        };

        self.text.forward_embeds(
            input_embeds,
            attention_mask.as_ref(),
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
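
Note: the image and video branches above implement a masked scatter: sequence positions holding the image (or video) token id take rows from the vision embeddings, and every other position keeps its text embedding. A self-contained candle sketch of the same pattern (not part of the commit; shapes and token ids are illustrative):

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let image_token_id = 7usize;
    // (batch = 1, seq = 4); token id 7 marks the image slots.
    let input_ids = Tensor::new(&[[5u32, 7, 7, 9]], &dev)?;
    // Text embeddings (zeros) and already-positioned image embeddings (ones),
    // both (1, 4, 3) so `where_cond` can select per position.
    let xs = Tensor::zeros((1, 4, 3), DType::F32, &dev)?;
    let image_embeds = Tensor::ones((1, 4, 3), DType::F32, &dev)?;
    let image_mask = input_ids
        .eq(image_token_id as f64)? // (1, 4) u8 mask over token positions
        .unsqueeze(D::Minus1)? // (1, 4, 1)
        .broadcast_as(xs.shape())?; // (1, 4, 3)
    let merged = image_mask.where_cond(&image_embeds, &xs)?;
    // Position 1 held an image token, so it now carries the vision embedding.
    assert_eq!(merged.i((0, 1, 0))?.to_scalar::<f32>()?, 1.0);
    Ok(())
}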

pub(crate) struct Qwen2VLVisionSpecificArgs {
    image_grid_thw: Option<Tensor>,
    video_grid_thw: Option<Tensor>,
    pixel_values_video: Option<Tensor>,
}

[CI annotations on line 133: the Clippy job fails and the Check, Docs, and Test Suite jobs (macOS, ubuntu, windows) warn that struct `Qwen2VLVisionSpecificArgs` is never constructed.]
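
Note: these args travel into `VisionModel::forward` as `Box<dyn Any>` and are recovered by downcast, as the impl below shows. A minimal sketch of that round-trip with a stand-in struct (not part of the commit):

use std::any::Any;

// Stand-in for the args struct; the real one carries candle `Tensor`s.
struct Args {
    image_rows: Option<u32>,
}

fn main() {
    // Callers box the model-specific args as `Box<dyn Any>`...
    let boxed: Box<dyn Any> = Box::new(Args { image_rows: Some(3) });
    // ...and the model downcasts them back to the concrete type.
    let Args { image_rows } = *boxed
        .downcast::<Args>()
        .expect("Cannot downcast into `Args`");
    assert_eq!(image_rows, Some(3));
}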

impl VisionModel for Qwen2VLModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        _start_offsets_kernel: Tensor,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Qwen2VLVisionSpecificArgs {
            image_grid_thw,
            video_grid_thw,
            pixel_values_video,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Qwen2VLVisionSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            pixel_values_video,
            image_grid_thw,
            video_grid_thw,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &Cache {
        &self.text.cache
    }
    fn device(&self) -> &Device {
        &self.text.device
    }
    fn max_seq_len(&self) -> usize {
        self.text.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.text.cfg
    }
}

impl IsqModel for Qwen2VLModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        todo!()
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        todo!()
    }
}

impl AnyMoeBaseModelMixin for Qwen2VLModel {}