diff --git a/mistralrs-core/src/vision_models/mod.rs b/mistralrs-core/src/vision_models/mod.rs index 903a22cfe..481b4a8a6 100644 --- a/mistralrs-core/src/vision_models/mod.rs +++ b/mistralrs-core/src/vision_models/mod.rs @@ -6,7 +6,6 @@ pub(crate) mod clip; pub(crate) mod idefics2; pub(crate) mod idefics2_input_processor; pub(crate) mod image_processor; - pub(crate) mod llava; pub(crate) mod phi3; pub(crate) mod phi3_inputs_processor; @@ -16,6 +15,7 @@ pub(crate) use llava::llava15; pub(crate) use llava::llava_inputs_processor; pub(crate) use llava::llava_next; pub(crate) use llava::llava_next_inputs_processor; +pub(crate) mod openvla; use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata; diff --git a/mistralrs-core/src/vision_models/openvla.rs b/mistralrs-core/src/vision_models/openvla.rs new file mode 100644 index 000000000..06b719cf1 --- /dev/null +++ b/mistralrs-core/src/vision_models/openvla.rs @@ -0,0 +1,53 @@ +use indexmap::IndexMap; +use serde::Deserialize; +use serde_json::Value; + +use crate::serde_default_fn; + +#[derive(Deserialize)] +enum LlmBackboneId { + #[serde(rename = "llama")] + Llama, + #[serde(rename = "mistral")] + Mistral, + #[serde(rename = "phi")] + Phi2, +} + +#[derive(Deserialize)] +enum ResizeStrategy { + #[serde(rename = "resize-naive")] + ResizeNaive, + #[serde(rename = "resize-crop")] + Resize, + #[serde(rename = "letterbox")] + Letterbox, +} + +serde_default_fn!(bool, output_proj_states, false); + +#[derive(Deserialize)] +pub struct OpenVLAConfig { + // Removed as they are redundant/unused info: + // vision_backbone_id: String, + // llm_max_length: usize, + + // Prismatic + llm_backbone_id: LlmBackboneId, + arch_specifier: String, + use_fused_vision_backbone: Option, + image_resize_strategy: ResizeStrategy, + text_config: Option, + pad_to_multiple_of: usize, + #[serde(default = "output_proj_states")] + output_projector_states: bool, + timm_model_ids: Vec, + timm_override_act_layers: Vec>, + image_sizes: Vec, 
+ + // OpenVLA + n_action_bins: usize, + #[allow(clippy::type_complexity)] + norm_stats: + Option>>>>>, +} diff --git a/mistralrs-core/src/vision_models/preprocessor_config.rs b/mistralrs-core/src/vision_models/preprocessor_config.rs index c4a6b7cd5..b1711520d 100644 --- a/mistralrs-core/src/vision_models/preprocessor_config.rs +++ b/mistralrs-core/src/vision_models/preprocessor_config.rs @@ -4,6 +4,26 @@ use candle_core::Result; use image::imageops::FilterType; use serde::Deserialize; +#[derive(Deserialize, Debug, Clone)] +pub(crate) struct VisionCropParams { + pub(crate) output_size: Vec, +} + +#[derive(Deserialize, Debug, Clone)] +pub(crate) struct VisionNormalizeParams { + pub(crate) inplace: bool, + pub(crate) mean: Option<[f64; 3]>, + pub(crate) std: Option<[f64; 3]>, +} + +#[derive(Deserialize, Debug, Clone)] +pub(crate) struct VisionResizeParams { + pub(crate) antialias: bool, + pub(crate) interpolation: usize, + pub(crate) max_size: Option, + pub(crate) size: (usize, usize), +} + #[derive(Deserialize, Debug, Clone)] #[allow(dead_code)] pub struct PreProcessorConfig { @@ -22,6 +42,17 @@ pub struct PreProcessorConfig { pub(crate) crop_size: Option>, pub(crate) num_img_tokens: Option, pub(crate) num_crops: Option, + // OpenVLA + pub(crate) means: Option>, + pub(crate) stds: Option>, + pub(crate) input_sizes: Option>, + pub(crate) tvf_crop_params: Option>, + pub(crate) tvf_do_letterbox: Option, + pub(crate) tvf_letterbox_fill: Option<(usize, usize, usize)>, + pub(crate) tvf_normalize_params: Option>, + pub(crate) tvf_resize_params: Option>, + pub(crate) use_fused_vision_backbone: Option, + pub(crate) interpolations: Option>, } #[allow(dead_code)] @@ -43,3 +74,17 @@ impl ToFilter for Option { } } } + +impl ToFilter for String { + // https://github.com/python-pillow/Pillow/blob/4b68563e8a818fb9c528fa159ddf3f4eaefa35e6/src/PIL/Image.py#L164-L170 + // Default: 
https://github.com/huggingface/transformers/blob/0df888ffb72ea370555efdef45985378d3cc7b2b/src/transformers/models/idefics2/image_processing_idefics2.py#L226 + fn to_filter(self) -> Result<FilterType> { + match self.to_lowercase().as_str() { + "nearest" => Ok(FilterType::Nearest), + "lanczos" => Ok(FilterType::Lanczos3), + "bilinear" => Ok(FilterType::Triangle), // BiLinear + "bicubic" => Ok(FilterType::CatmullRom), // BiCubic + x => candle_core::bail!("Filter {x} not supported"), + } + } +}