qwen2vl support #1042

Open · wants to merge 15 commits into base: main
12 changes: 8 additions & 4 deletions notebooks/openvino/sentence_transformer_quantization.ipynb
@@ -170,9 +170,11 @@
],
"source": [
"from functools import partial\n",
"import datasets\n",
"\n",
"from transformers import AutoTokenizer\n",
"from optimum.intel import OVModelForFeatureExtraction, OVQuantizer, OVQuantizationConfig, OVConfig\n",
"\n",
"from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizationConfig, OVQuantizer\n",
"\n",
"\n",
"MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
"base_model_path = \"all-MiniLM-L6-v2\"\n",
@@ -187,6 +189,7 @@
"\n",
"quantizer = OVQuantizer.from_pretrained(model)\n",
"\n",
"\n",
"def preprocess_function(examples, tokenizer):\n",
" return tokenizer(examples[\"sentence\"], padding=\"max_length\", max_length=384, truncation=True)\n",
"\n",
@@ -225,9 +228,9 @@
"metadata": {},
"outputs": [],
"source": [
"from transformers import Pipeline\n",
"import torch.nn.functional as F\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from transformers import Pipeline\n",
"\n",
"\n",
"# copied from the model card \"sentence-transformers/all-MiniLM-L6-v2\"\n",
@@ -296,6 +299,7 @@
"from datasets import load_dataset\n",
"from evaluate import load\n",
"\n",
"\n",
"eval_dataset = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
"metric = load(\"glue\", \"stsb\")"
]
258 changes: 252 additions & 6 deletions optimum/exporters/openvino/model_configs.py
@@ -88,6 +88,8 @@
PersimmonModelPatcher,
Phi3ModelPatcher,
Phi3VisionImageEmbeddingsPatcher,
Qwen2VLLanguageModelPatcher,
Qwen2VLVisionEmbMergerPatcher,
QwenModelPatcher,
RotaryEmbPatcher,
UpdateCausalMaskModelPatcher,
@@ -106,6 +108,10 @@ def init_model_configs():
"transformers",
"LlavaNextForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "qwen2-vl", "image-text-to-text")] = (
"transformers",
"Qwen2VLForConditionalGeneration",
)
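# As with llava-next above, qwen2-vl checkpoints are loaded for the image-text-to-text
# task through the task-specific Qwen2VLForConditionalGeneration class.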
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
"image-text-to-text"
] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
@@ -1288,18 +1294,26 @@ def patch_model_for_export(


class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
def __init__(self, export_config):
def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None):
self.orig_export_config = export_config
if dummy_input_generator is not None:
export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
dummy_input_generator,
) + export_config.DUMMY_INPUT_GENERATOR_CLASSES
self.DUMMY_INPUT_GENERATOR_CLASSES = export_config.DUMMY_INPUT_GENERATOR_CLASSES
self.DEFAULT_ONNX_OPSET = export_config.DEFAULT_ONNX_OPSET
self.DUMMY_PKV_GENERATOR_CLASS = export_config.DUMMY_PKV_GENERATOR_CLASS
self._config = export_config._config
self._normalized_config = export_config._normalized_config
self.use_past = export_config.use_past
self.patcher_cls = patcher_cls
self.input_info_upd = inputs_update
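# Optional hooks: patcher_cls (if set) takes over patch_model_for_export below, and
# input_info_upd lets callers add or override dynamic axes for extra exported inputs.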

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
if self.patcher_cls is not None:
return self.patcher_cls(self, model, model_kwargs=model_kwargs)
# Refer to DecoderModelPatcher.
return self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)

@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
orig_inputs = self.orig_export_config.inputs
input_ids_config = orig_inputs.pop("input_ids")
orig_inputs["inputs_embeds"] = input_ids_config
if self.input_info_upd is not None:
orig_inputs.update(self.input_info_upd)
return orig_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
return export_config


def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype):
def get_vlm_text_generation_config(
model_type,
model_config,
int_dtype,
float_dtype,
model_patcher=None,
dummy_input_generator=None,
inputs_update=None,
):
internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
export_config = LMInputEmbedsConfigHelper(internal_export_config)
export_config = LMInputEmbedsConfigHelper(
internal_export_config,
patcher_cls=model_patcher,
dummy_input_generator=dummy_input_generator,
inputs_update=inputs_update,
)
export_config._normalized_config = internal_export_config._normalized_config
return export_config
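# The new model_patcher / dummy_input_generator / inputs_update arguments are simply
# forwarded to LMInputEmbedsConfigHelper, so model-specific configs (e.g. qwen2-vl
# below) can customize the language-model patcher, dummy inputs, and dynamic axes
# without subclassing the helper.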

@@ -1820,9 +1849,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
img_ids_height = self.height // 2
img_ids_width = self.width // 2
return self.random_int_tensor(
[self.batch_size, img_ids_height * img_ids_width, 3]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3],
(
[self.batch_size, img_ids_height * img_ids_width, 3]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3]
),
min_value=0,
max_value=min(img_ids_height, img_ids_width),
framework=framework,
@@ -2259,3 +2290,218 @@ def patch_model_for_export(
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)


class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
generated_input = super().generate(input_name, framework, int_dtype, float_dtype)
if input_name == "position_ids":
return generated_input.unsqueeze(0).expand(3, -1, -1)
return generated_input
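# Qwen2-VL applies multimodal rotary position embeddings (M-RoPE), so position_ids
# carry three components (temporal, height, width); the dummy tensor is therefore
# expanded to shape (3, batch_size, sequence_length).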


class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states",)

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = 1,
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = 420,
height: int = 420,
**kwargs,
):
self.batch_size = batch_size
self.height = height
self.width = width
self.num_channels = num_channels
self.temporal_patch_size = normalized_config.config.temporal_patch_size
self.patch_size = normalized_config.config.patch_size

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
grid_t = self.batch_size
shape = [
grid_t * grid_h * grid_w,
self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size,
]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
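# Illustrative shape, assuming the stock Qwen2-VL vision config (patch_size=14,
# temporal_patch_size=2, 3 channels): a 420x420 dummy image gives a 30x30 patch grid,
# i.e. hidden_states of shape [1 * 30 * 30, 3 * 2 * 14 * 14] = [900, 1176].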


class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = 1,
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = 420,
height: int = 420,
**kwargs,
):
self.batch_size = batch_size
self.height = height
self.width = width
self.num_channels = num_channels
self.temporal_patch_size = normalized_config.config.temporal_patch_size
self.patch_size = normalized_config.config.patch_size
self.embed_dim = normalized_config.config.embed_dim
self.num_heads = normalized_config.config.num_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
grid_t = self.batch_size

if input_name == "hidden_states":
return self.random_float_tensor(
[grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
)

if input_name == "attention_mask":
return self.random_mask_tensor(
[1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype
)

if input_name == "rotary_pos_emb":
dim = self.embed_dim // self.num_heads // 2
return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype)
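# rotary_pos_emb has one entry per rotary frequency, i.e. half of the per-head
# dimension (embed_dim // num_heads // 2), for every patch position.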


class Qwen2VLConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
VISION_EMBEDDINGS = "vision_embeddings"
VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
TEXT_EMBEDDINGS = "text_embeddings"
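# Each behavior maps to a separately exported submodel (see get_model_for_behavior):
# LANGUAGE -> the full model, exported as a decoder driven by inputs_embeds;
# VISION_EMBEDDINGS -> model.visual.patch_embed; VISION_EMBEDDINGS_MERGER ->
# model.visual (vision blocks + patch merger); TEXT_EMBEDDINGS -> model.model.embed_tokens.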


@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
class Qwen2VLOpenVINOConfig(OnnxConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
)
self._behavior = behavior
self._orig_config = config
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,)

@staticmethod
def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
behavior = Qwen2VLConfigBehavior(behavior)

if behavior == Qwen2VLConfigBehavior.LANGUAGE:
return model

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
vision_embeddings = model.visual.patch_embed
vision_embeddings.config = model.config.vision_config
return vision_embeddings

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
vision_emb_merger = model.visual
vision_emb_merger.config = model.config.vision_config
return vision_emb_merger

if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
text_embedding = model.model.embed_tokens
text_embedding.config = model.config
return text_embedding

def with_behavior(
self,
behavior: Union[str, Qwen2VLConfigBehavior],
):
"""
Creates a config for a different behavior.
Args:
behavior ([`Qwen2VLConfigBehavior`]):
The behavior to use for the new instance.
"""
if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
behavior = Qwen2VLConfigBehavior(behavior)

if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

if behavior == Qwen2VLConfigBehavior.LANGUAGE:
return get_vlm_text_generation_config(
"qwen2",
self._orig_config,
self.int_dtype,
self.float_dtype,
model_patcher=Qwen2VLLanguageModelPatcher,
dummy_input_generator=DummyQwen2VLLMInputGenerator,
inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
)

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)
if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
model_kwargs = model_kwargs or {}
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return {
"hidden_states": {0: "sequence_length"},
"attention_mask": {1: "sequence_length", 2: "sequence_length"},
"rotary_pos_emb": {0: "sequence_length"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
return {"last_hidden_state": {0: "seq_len"}}
return {}
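# Illustrative usage sketch (not part of this diff): how the pieces above fit together
# when exporting qwen2-vl. The checkpoint name is only an example; the optimum-intel
# export pipeline normally drives these calls itself.
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
export_config = Qwen2VLOpenVINOConfig(model.config, task="image-text-to-text")
for behavior in Qwen2VLOpenVINOConfig.SUPPORTED_BEHAVIORS:
    # pick the submodel and the matching export config for this behavior
    submodel = Qwen2VLOpenVINOConfig.get_model_for_behavior(model, behavior)
    behavior_config = export_config.with_behavior(behavior)
    # each (submodel, behavior_config) pair is then converted to its own OpenVINO model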