Add the OpenVLA model #572

Open
wants to merge 3 commits into
base: master
2 changes: 1 addition & 1 deletion mistralrs-core/src/vision_models/mod.rs
@@ -6,7 +6,6 @@ pub(crate) mod clip;
pub(crate) mod idefics2;
pub(crate) mod idefics2_input_processor;
pub(crate) mod image_processor;

pub(crate) mod llava;
pub(crate) mod phi3;
pub(crate) mod phi3_inputs_processor;
@@ -16,6 +15,7 @@ pub(crate) use llava::llava15;
pub(crate) use llava::llava_inputs_processor;
pub(crate) use llava::llava_next;
pub(crate) use llava::llava_next_inputs_processor;
pub(crate) mod openvla;

use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;

53 changes: 53 additions & 0 deletions mistralrs-core/src/vision_models/openvla.rs
@@ -0,0 +1,53 @@
use indexmap::IndexMap;
use serde::Deserialize;
use serde_json::Value;

use crate::serde_default_fn;

#[derive(Deserialize)]
enum LlmBackboneId {
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "phi")]
    Phi2,
}

#[derive(Deserialize)]
enum ResizeStrategy {
    #[serde(rename = "resize-naive")]
    ResizeNaive,
    #[serde(rename = "resize-crop")]
    Resize,
    #[serde(rename = "letterbox")]
    Letterbox,
}

serde_default_fn!(bool, output_proj_states, false);

#[derive(Deserialize)]
pub struct OpenVLAConfig {
    // Removed as they are redundant/unused info:
    // vision_backbone_id: String,
    // llm_max_length: usize,

    // Prismatic
    // CI annotation on line 36: "multiple fields are never read"
    // (failure in Clippy; warning in Check, Docs, and Test Suite)
    llm_backbone_id: LlmBackboneId,
    arch_specifier: String,
    use_fused_vision_backbone: Option<bool>,
    image_resize_strategy: ResizeStrategy,
    text_config: Option<Value>,
    pad_to_multiple_of: usize,
    #[serde(default = "output_proj_states")]
    output_projector_states: bool,
    timm_model_ids: Vec<String>,
    timm_override_act_layers: Vec<Option<String>>,
    image_sizes: Vec<usize>,

    // OpenVLA
    n_action_bins: usize,
    #[allow(clippy::type_complexity)]
    norm_stats:
        Option<IndexMap<String, IndexMap<String, IndexMap<String, IndexMap<String, Vec<f64>>>>>>,
}
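
Reader's note on `n_action_bins` and `norm_stats`: this diff only declares the fields and does not show how they are consumed. The following is a minimal sketch of one plausible use, assuming (not confirmed by this PR) that actions are discretized into bins whose edges are evenly spaced over [-1, 1] and are then rescaled to a per-dimension range taken from the statistics. The function names and the `low`/`high` parameters are hypothetical, and only the standard library is used.

```rust
/// Centers of the intervals between `n_bins` evenly spaced edges over [-1, 1].
/// Returns `n_bins - 1` values. (Hypothetical helper, not part of the PR.)
fn bin_centers(n_bins: usize) -> Vec<f64> {
    assert!(n_bins >= 2, "need at least two edges to form one bin");
    // Distance between two neighbouring edges.
    let step = 2.0 / (n_bins as f64 - 1.0);
    (0..n_bins - 1)
        .map(|i| -1.0 + step * (i as f64 + 0.5))
        .collect()
}

/// Rescale normalized actions in [-1, 1] to `[low, high]` per dimension.
/// `low` and `high` are assumed to come from the dataset statistics.
fn unnormalize(normalized: &[f64], low: &[f64], high: &[f64]) -> Vec<f64> {
    assert_eq!(normalized.len(), low.len());
    assert_eq!(normalized.len(), high.len());
    normalized
        .iter()
        .zip(low.iter().zip(high.iter()))
        .map(|(&a, (&lo, &hi))| 0.5 * (a + 1.0) * (hi - lo) + lo)
        .collect()
}
```

For example, `bin_centers(3)` yields `[-0.5, 0.5]`, and `unnormalize(&[0.0], &[0.0], &[10.0])` yields `[5.0]`.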
45 changes: 45 additions & 0 deletions mistralrs-core/src/vision_models/preprocessor_config.rs
@@ -4,6 +4,26 @@
use image::imageops::FilterType;
use serde::Deserialize;

#[derive(Deserialize, Debug, Clone)]
pub(crate) struct VisionCropParams {
    // CI annotation on line 9: "field `output_size` is never read"
    // (failure in Clippy; warning in Check, Docs, and Test Suite)
    pub(crate) output_size: Vec<usize>,
}

#[derive(Deserialize, Debug, Clone)]
pub(crate) struct VisionNormalizeParams {
    // CI annotation on line 14: "fields `inplace`, `mean`, and `std` are never read"
    // (failure in Clippy; warning in Check, Docs, and Test Suite)
    pub(crate) inplace: bool,
    pub(crate) mean: Option<[f64; 3]>,
    pub(crate) std: Option<[f64; 3]>,
}
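
Reader's note on `VisionNormalizeParams`: the `mean` and `std` fields are not read anywhere in this diff. The sketch below shows the conventional per-channel normalization, `(x - mean) / std`, that such parameters usually describe. `normalize_pixel` is a hypothetical helper and not part of the PR.

```rust
/// Normalize one RGB pixel channel by channel: (x - mean) / std.
/// Hypothetical helper; inputs are assumed to be already scaled to [0, 1].
fn normalize_pixel(rgb: [f64; 3], mean: [f64; 3], std_dev: [f64; 3]) -> [f64; 3] {
    std::array::from_fn(|c| (rgb[c] - mean[c]) / std_dev[c])
}
```

With `mean` and `std` both set to `[0.5; 3]`, this maps the range [0, 1] onto [-1, 1].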

#[derive(Deserialize, Debug, Clone)]
pub(crate) struct VisionResizeParams {
    // CI annotation on line 21: "fields `antialias`, `interpolation`, `max_size`, and `size` are never read"
    // (failure in Clippy; warning in Check, Docs, and Test Suite)
    pub(crate) antialias: bool,
    pub(crate) interpolation: usize,
    pub(crate) max_size: Option<usize>,
    pub(crate) size: (usize, usize),
}

#[derive(Deserialize, Debug, Clone)]
#[allow(dead_code)]
pub struct PreProcessorConfig {
@@ -22,6 +42,17 @@
    pub(crate) crop_size: Option<HashMap<String, u32>>,
    pub(crate) num_img_tokens: Option<usize>,
    pub(crate) num_crops: Option<usize>,
    // OpenVLA
    pub(crate) means: Option<Vec<[f64; 3]>>,
    pub(crate) stds: Option<Vec<[f64; 3]>>,
    pub(crate) input_sizes: Option<Vec<[usize; 3]>>,
    pub(crate) tvf_crop_params: Option<Vec<VisionCropParams>>,
    pub(crate) tvf_do_letterbox: Option<bool>,
    pub(crate) tvf_letterbox_fill: Option<(usize, usize, usize)>,
    pub(crate) tvf_normalize_params: Option<Vec<VisionNormalizeParams>>,
    pub(crate) tvf_resize_params: Option<Vec<VisionResizeParams>>,
    pub(crate) use_fused_vision_backbone: Option<bool>,
    pub(crate) interpolations: Option<Vec<String>>,
}
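
Reader's note on `tvf_do_letterbox` and `tvf_letterbox_fill`: the field names suggest padding the image to a square with a fill colour before resizing. The PR does not implement this step, so the following is only a rough sketch under that assumption. `letterbox_padding` is a hypothetical name, and centering the image on the canvas is an assumed design choice.

```rust
/// Padding `(left, top, right, bottom)` that centers a `width` x `height`
/// image on a square canvas whose side is the longer of the two dimensions.
/// Hypothetical helper; any odd leftover pixel goes to the right or bottom.
fn letterbox_padding(width: usize, height: usize) -> (usize, usize, usize, usize) {
    let side = width.max(height);
    let pad_w = side - width;
    let pad_h = side - height;
    (pad_w / 2, pad_h / 2, pad_w - pad_w / 2, pad_h - pad_h / 2)
}
```

A 640x480 image would receive 80 pixels of padding above and below and none on the sides; the padded area would then be filled with the `tvf_letterbox_fill` colour.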

#[allow(dead_code)]
@@ -43,3 +74,17 @@
        }
    }
}

impl ToFilter for String {
    // https://github.com/python-pillow/Pillow/blob/4b68563e8a818fb9c528fa159ddf3f4eaefa35e6/src/PIL/Image.py#L164-L170
    // Default: https://github.com/huggingface/transformers/blob/0df888ffb72ea370555efdef45985378d3cc7b2b/src/transformers/models/idefics2/image_processing_idefics2.py#L226
    fn to_filter(self) -> Result<FilterType> {
        match self.to_lowercase().as_str() {
            "nearest" => Ok(FilterType::Nearest),
            "lanczos" => Ok(FilterType::Lanczos3),
            "bilinear" => Ok(FilterType::Triangle), // BiLinear
            "bicubic" => Ok(FilterType::CatmullRom), // BiCubic
            x => candle_core::bail!("Filter {x} not supported"),
        }
    }
}
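
Reader's note: the name-to-filter mapping above can be exercised without the `image` crate by substituting a local enum. The sketch below uses only the standard library. `Filter` is a stand-in for `image::imageops::FilterType`, and `parse_filter` with a `String` error is a hypothetical replacement for `to_filter` and `candle_core::bail!`.

```rust
/// Stand-in for `image::imageops::FilterType`, so that the sketch needs no
/// external crates. Only the variants used by the mapping are listed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Filter {
    Nearest,
    Triangle,
    CatmullRom,
    Lanczos3,
}

/// Case-insensitive lookup of an interpolation name, mirroring the match
/// arms in the diff. Unknown names produce an error instead of a default.
fn parse_filter(name: &str) -> Result<Filter, String> {
    match name.to_lowercase().as_str() {
        "nearest" => Ok(Filter::Nearest),
        "lanczos" => Ok(Filter::Lanczos3),
        "bilinear" => Ok(Filter::Triangle),
        "bicubic" => Ok(Filter::CatmullRom),
        other => Err(format!("Filter {other} not supported")),
    }
}
```

Because the input is lowercased before matching, `"BiCubic"` and `"bicubic"` resolve to the same variant, while a name outside the four listed (for example `"box"`) is rejected.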