Skip to content

Commit

Permalink
fix typings in patchers
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 18, 2024
1 parent f791b94 commit 0eb94f5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
10 changes: 4 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
from optimum.intel.utils.import_utils import (
_openvino_version,
Expand Down Expand Up @@ -3385,8 +3383,8 @@ def __exit__(self, exc_type, exc_value, traceback):
class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: OnnxConfig,
model: Union[PreTrainedModel, TFPreTrainedModel],
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward
Expand Down Expand Up @@ -3424,8 +3422,8 @@ def __exit__(self, exc_type, exc_value, traceback):
class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
def __init__(
self,
config: OnnxConfig,
model: Union[PreTrainedModel, TFPreTrainedModel],
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -2155,6 +2155,7 @@ def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs):
pixel_values=hidden_states, attention_mask=causal_mask, rotary_pos_emb=rotary_pos_emb
)[0]
return res

# Adopted from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1089
# Use config values instead of model attributes, replace self.rotary_pos_emb -> self._rotary_pos_emb
def rot_pos_emb(self, grid_thw):
Expand Down

0 comments on commit 0eb94f5

Please sign in to comment.