Skip to content

Commit

Permalink
apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 17, 2024
1 parent bf44a19 commit 2f92e6e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 37 deletions.
44 changes: 9 additions & 35 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,37 +2300,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return generated_input


class DummyQwen2VLVisionEMbedInputGenerator(DummyVisionInputGenerator):
    """Dummy-input generator producing the ``hidden_states`` tensor for Qwen2-VL vision embeddings.

    The generated tensor is two-dimensional; its shape is derived from the
    image size, the batch size and the patch settings read from the vision
    config (see :meth:`generate`).
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states",)

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        """Store the shape parameters used later by :meth:`generate`.

        Args:
            task: Accepted for signature compatibility; not used here.
            normalized_config: Wrapper whose ``config`` attribute provides
                ``temporal_patch_size`` and ``patch_size``.
            batch_size: Number of samples; used as the temporal grid size.
            num_channels: Number of image channels.
            width: Image width used to derive the patch grid.
            height: Image height used to derive the patch grid.
            **kwargs: Accepted and ignored.

        Note:
            The parent initializer is intentionally not invoked here; only the
            attributes below are set.
        """
        vision_config = normalized_config.config

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.patch_size = vision_config.patch_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return a random float tensor of shape ``[num_patches, patch_features]``.

        ``num_patches`` is ``batch_size * (height // patch_size) * (width // patch_size)``
        and ``patch_features`` is
        ``num_channels * temporal_patch_size * patch_size * patch_size``.

        ``input_name`` and ``int_dtype`` are part of the common generator
        signature but are not consulted by this implementation.
        """
        rows = self.height // self.patch_size
        cols = self.width // self.patch_size
        # The batch size plays the role of the temporal grid dimension.
        num_patches = self.batch_size * rows * cols
        patch_features = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
        return self.random_float_tensor(
            [num_patches, patch_features],
            framework=framework,
            dtype=float_dtype,
        )


class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator):
class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

def __init__(
Expand All @@ -2349,7 +2319,10 @@ def __init__(
self.num_channels = num_channels
self.temporal_patch_size = normalized_config.config.temporal_patch_size
self.patch_size = normalized_config.config.patch_size
self.embed_dim = normalized_config.config.embed_dim
if normalized_config.use_embed_dim:
self.embed_dim = normalized_config.config.embed_dim
else:
self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
self.num_heads = normalized_config.config.num_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
Expand Down Expand Up @@ -2382,7 +2355,7 @@ class Qwen2VLConfigBehavior(str, enum.Enum):
class Qwen2VLOpenVINOConfig(OnnxConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

def __init__(
Expand All @@ -2405,12 +2378,13 @@ def __init__(
self._orig_config = config
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
self._normalized_config.use_embed_dim = False
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,)
self._normalized_config.use_embed_dim = True

@staticmethod
def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
Expand Down
5 changes: 3 additions & 2 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
from ...exporters.openvino.utils import save_config
from .. import OVQuantizer
from ..utils.import_utils import is_transformers_version
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
Expand Down Expand Up @@ -2096,13 +2097,13 @@ def __init__(
quantization_config=quantization_config,
**kwargs,
)
try:
if is_transformers_version(">=", "4.45.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding

self._rotary_pos_emb = VisionRotaryEmbedding(
self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2
)
except ImportError:
else:
raise ValueError(
f"Initialization model for {self.config.model_type} required at least transformers >= 4.45"
)
Expand Down

0 comments on commit 2f92e6e

Please sign in to comment.