From 216da6a61784b0217543fb3e0fb54bc0d3a8c2e1 Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 4 Nov 2024 21:09:01 +0400 Subject: [PATCH] add test --- optimum/exporters/openvino/__main__.py | 5 +- optimum/exporters/openvino/convert.py | 8 ++- optimum/exporters/openvino/model_configs.py | 23 ++++---- optimum/exporters/openvino/model_patcher.py | 5 +- .../openvino/modeling_visual_language.py | 57 ++++++++++--------- tests/openvino/test_modeling.py | 39 +++++++++---- tests/openvino/utils_tests.py | 1 + 7 files changed, 82 insertions(+), 56 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index d630936743..09b820b3b5 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -266,13 +266,10 @@ def main_export( if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: loading_kwargs["attn_implementation"] = "eager" - + # some models force flash_attn attention by default thta is not available for cpu - logger.warn(model_type) if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES: loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type] - - logger.warn(loading_kwargs) # there are some difference between remote and in library representation of past key values for some models, # for avoiding confusion we disable remote code for them if ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 8d4605936c..2c08632add 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -682,12 +682,16 @@ def export_from_model( model_name_or_path = model.config._name_or_path if preprocessors is not None: + # phi3-vision processor does not have chat_template attribute that breaks Processor saving on disk + if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1: + if not hasattr(preprocessors[1], "chat_template"): + preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None) for processor in preprocessors: try: processor.save_pretrained(output) except Exception as ex: logger.error(f"Saving {type(processor)} failed with {ex}") - else: + else: maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code) files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()] @@ -848,7 +852,7 @@ def _get_multi_modal_submodels_and_export_configs( if model_type == "internvl-chat" and preprocessors is not None: model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("") - + if model_type == "phi3-v": model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist() model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist() diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 40bd9f5bbe..921cf17d84 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -1749,7 +1749,8 @@ class Phi3VisionConfigBehavior(str, enum.Enum): class DummyPhi3VisionProjectionInputGenerator(DummyVisionInputGenerator): - SUPPORTED_INPUT_NAMES = ("input", ) + SUPPORTED_INPUT_NAMES = ("input",) + def __init__( self, task: str, @@ -1762,10 +1763,10 @@ def __init__( ): self.batch_size = batch_size self._embed_layer_realization = normalized_config.config.embd_layer["embedding_cls"] - self.image_dim_out = normalized_config.config.img_processor['image_dim_out'] + self.image_dim_out = normalized_config.config.img_processor["image_dim_out"] self.height = height self.width = width - + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): h = self.height // 336 w = self.width // 336 @@ -1777,7 +1778,6 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) - @register_in_tasks_manager("phi3-v", *["image-text-to-text"], library_name="transformers") class Phi3VisionOpenVINOConfig(OnnxConfig): SUPPORTED_BEHAVIORS = [model_type.value for model_type in Phi3VisionConfigBehavior] @@ -1804,14 +1804,15 @@ def __init__( self._behavior = behavior self._orig_config = config if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "img_processor"): - self._config = AutoConfig.from_pretrained(config.img_processor["model_name"], trust_remote_code=True).vision_config + self._config = AutoConfig.from_pretrained( + config.img_processor["model_name"], trust_remote_code=True + ).vision_config self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) if self._behavior == Phi3VisionConfigBehavior.VISION_PROJECTION and hasattr(config, "img_processor"): self._config = config self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) - self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyPhi3VisionProjectionInputGenerator, ) - + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyPhi3VisionProjectionInputGenerator,) @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -1823,7 +1824,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: if self._behavior in [Phi3VisionConfigBehavior.VISION_EMBEDDINGS, Phi3VisionConfigBehavior.VISION_PROJECTION]: - return {"last_hidden_state": {0: "batch_size", 1: "height_width_projection"}} + return {"last_hidden_state": {0: "batch_size", 1: "height_width_projection"}} return {} def with_behavior( @@ -1928,8 +1929,8 @@ def get_model_for_behavior(self, model, behavior: Union[str, Phi3VisionConfigBeh vision_embeddings = model.model.vision_embed_tokens vision_embeddings.config = model.config return vision_embeddings - - if behavior == Phi3VisionConfigBehavior.VISION_PROJECTION: + + if behavior == Phi3VisionConfigBehavior.VISION_PROJECTION: projection = model.model.vision_embed_tokens.img_projection projection.config = model.config return projection @@ -1945,4 +1946,4 @@ def patch_model_for_export( model_kwargs = model_kwargs or {} if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs) - return super().patch_model_for_export(model, model_kwargs) \ No newline at end of file + return super().patch_model_for_export(model, model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 73412747f4..ca716d5fce 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -1362,7 +1362,7 @@ def phi3_442_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -2769,6 +2769,7 @@ def __exit__(self, exc_type, exc_value, traceback): def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor): return self.get_img_features(pixel_values) + class Phi3VisionImageEmbeddingsPatcher(ModelPatcher): def __init__( self, @@ -2782,4 +2783,4 @@ def __init__( def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - self._model.forward = self._model.__orig_forward \ No newline at end of file + self._model.forward = self._model.__orig_forward diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index ed5bb5ed17..eea9fd5e97 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -200,9 +200,7 @@ def forward(self, img_features): return self.request(img_features)[0] -MODEL_PARTS_CLS_MAPPING = { - "vision_projection": OVVisionProjection -} +MODEL_PARTS_CLS_MAPPING = {"vision_projection": OVVisionProjection} class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin): @@ -522,7 +520,7 @@ def _from_transformers( ov_config=ov_config, stateful=stateful, ) - config = AutoConfig.from_pretrained(save_dir_path) + config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code) return cls._from_pretrained( model_id=save_dir_path, config=config, @@ -1148,13 +1146,26 @@ def __init__( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): - super().__init__(language_model, text_embeddings, vision_embeddings, config, device, dynamic_shapes, ov_config, model_save_dir, quantization_config, **kwargs) + super().__init__( + language_model, + text_embeddings, + vision_embeddings, + config, + device, + dynamic_shapes, + ov_config, + model_save_dir, + quantization_config, + **kwargs, + ) self.sub_GN = torch.tensor(self.config.sub_GN) self.glb_GN = torch.tensor(self.config.glb_GN) def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs): num_images, num_crops, c, h, w = pixel_values.shape - img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(num_images, num_crops, -1, self.config.img_processor['image_dim_out']) + img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape( + num_images, num_crops, -1, self.config.img_processor["image_dim_out"] + ) image_features_proj = self.hd_feature_transform(img_features, image_sizes) return image_features_proj @@ -1181,9 +1192,7 @@ def hd_feature_transform(self, image_features, image_sizes): # NOTE: real num_crops is padded # (num_crops, 24*24, 1024) sub_image_features = image_features[i, 1 : 1 + num_crops] - sub_image_features_hd = self.reshape_hd_patches_2x2merge( - sub_image_features, h_crop, w_crop - ) + sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop) sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd) # [sub features, separator, global features] @@ -1194,9 +1203,7 @@ def hd_feature_transform(self, image_features, image_sizes): global_image_features_hd_newline[i], ] ) - image_features_proj = self.vision_projection( - torch.cat(all_image_embeddings, dim=0).unsqueeze(0) - )[0] + image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0] return image_features_proj @@ -1214,13 +1221,9 @@ def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 .reshape(N, -1, 4 * C) # N, 144, 4096 - .reshape( - num_images, h_crop, w_crop, H // 2, H // 2, -1 - ) # n_img, h_crop, w_crop, 12, 12, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 - .reshape( - num_images, h_crop * H // 2, w_crop * H // 2, 4 * C - ) # n_img, h_crop*12, w_crop*12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 ) return image_features_hd @@ -1233,13 +1236,13 @@ def add_image_newline(self, image_features_hd): num_images, h, w, hid_dim = image_features_hd.shape # add the newline token to the HD image feature patches newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) - image_features_hd_newline = torch.cat( - [image_features_hd, newline_embeddings], dim=2 - ).reshape(num_images, -1, hid_dim) + image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape( + num_images, -1, hid_dim + ) return image_features_hd_newline def get_multimodal_embeddings( - self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs + self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs ): MAX_INPUT_ID = int(1e9) input_shape = input_ids.size() @@ -1251,16 +1254,18 @@ def get_multimodal_embeddings( input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size) inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs)) if has_image: - vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs) + vision_embeds = self.get_vision_embeddings( + pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs + ) image_features_proj = torch.from_numpy(vision_embeds) inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False) - + return inputs_embeds, attention_mask, position_ids - + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, "internvl_chat": _OvInternVLForCausalLM, - "phi3_v": _OVPhi3VisionForCausalLM + "phi3_v": _OVPhi3VisionForCausalLM, } diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 082ffef285..596809e562 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -50,6 +50,7 @@ AutoModelForSpeechSeq2Seq, AutoModelForTokenClassification, AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, GenerationConfig, Pix2StructForConditionalGeneration, @@ -1880,9 +1881,11 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase): ] if is_transformers_version(">=", "4.40.0"): - SUPPORTED_ARCHITECTURES += ["llava_next"] + SUPPORTED_ARCHITECTURES += ["llava_next", "phi3_v"] TASK = "image-text-to-text" + REMOTE_CODE_MODELS = ["phi3_v"] + IMAGE = Image.open( requests.get( "http://images.cocodataset.org/val2017/000000039769.jpg", @@ -1899,19 +1902,27 @@ def get_transformer_model_class(self, model_arch): from transformers import LlavaNextForConditionalGeneration return LlavaNextForConditionalGeneration - return None + return AutoModelForCausalLM @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): - prompt = "\n What is shown in this image?" + prompt = ( + "\n What is shown in this image?" + if not "phi3_v" in model_arch + else "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n" + ) model_id = MODEL_NAMES[model_arch] - processor = get_preprocessor(model_id) - transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + transformers_model = self.get_transformer_model_class(model_arch).from_pretrained( + model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS + ) inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt") set_seed(SEED) with torch.no_grad(): transformers_outputs = transformers_model(**inputs) - ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True) + ov_model = OVModelForVisualCausalLM.from_pretrained( + model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS + ) self.assertIsInstance(ov_model, MODEL_TYPE_TO_CLS_MAPPING[ov_model.config.model_type]) self.assertIsInstance(ov_model.vision_embeddings, OVVisionEmbedding) self.assertIsInstance(ov_model.language_model, OVModelWithEmbedForCausalLM) @@ -1950,20 +1961,26 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_generate_utils(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True) - preprocessor = get_preprocessor(model_id) - question = "\nDescribe image" + model = OVModelForVisualCausalLM.from_pretrained( + model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS + ) + preprocessor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + question = ( + "\nDescribe image" + if not "phi3_v" in model_arch + else "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n" + ) inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt") # General case outputs = model.generate(**inputs, max_new_tokens=10) - outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True) + outputs = preprocessor.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) self.assertIsInstance(outputs[0], str) question = "Hi, how are you?" inputs = preprocessor(images=None, text=question, return_tensors="pt") outputs = model.generate(**inputs, max_new_tokens=10) - outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True) + outputs = preprocessor.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) self.assertIsInstance(outputs[0], str) del model diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index e5a9f73a64..80cfc67b20 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -104,6 +104,7 @@ "pix2struct": "fxmarty/pix2struct-tiny-random", "phi": "echarlaix/tiny-random-PhiForCausalLM", "phi3": "Xenova/tiny-random-Phi3ForCausalLM", + "phi3_v": "katuni4ka/tiny-random-phi3-vision", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", "qwen2": "fxmarty/tiny-dummy-qwen2",