Commit

add test

eaidova committed Nov 4, 2024
1 parent 7862263 commit 4b72b53
Showing 7 changed files with 83 additions and 57 deletions.
5 changes: 1 addition & 4 deletions optimum/exporters/openvino/__main__.py
@@ -266,13 +266,10 @@ def main_export(

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that is not available for cpu
logger.warn(model_type)
if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]

logger.warn(loading_kwargs)
# there are some difference between remote and in library representation of past key values for some models,
# for avoiding confusion we disable remote code for them
if (
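As context for the hunk above, here is a minimal sketch (not part of the commit) of how the forced attention implementation reaches the model: the value looked up in `FORCE_ATTN_MODEL_CLASSES` is passed to `from_pretrained` through `loading_kwargs`, so a model type that would otherwise default to flash attention (unavailable on CPU) is loaded with another implementation. The `"eager"` mapping below is an assumption for illustration, not the repository's actual table.

```python
# Illustrative sketch only; the FORCE_ATTN_MODEL_CLASSES content is assumed, not taken from the repo.
FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}  # hypothetical mapping: model_type -> attn implementation


def build_loading_kwargs(model_type: str) -> dict:
    loading_kwargs = {}
    if model_type in FORCE_ATTN_MODEL_CLASSES:
        # transformers reads the private _attn_implementation kwarg when instantiating the model
        loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
    return loading_kwargs


print(build_loading_kwargs("phi3-v"))  # {'_attn_implementation': 'eager'}
```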
8 changes: 6 additions & 2 deletions optimum/exporters/openvino/convert.py
@@ -682,12 +682,16 @@ def export_from_model(

model_name_or_path = model.config._name_or_path
if preprocessors is not None:
# phi3-vision processor does not have chat_template attribute that breaks Processor saving on disk
if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1:
if not hasattr(preprocessors[1], "chat_template"):
preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None)
for processor in preprocessors:
try:
processor.save_pretrained(output)
except Exception as ex:
logger.error(f"Saving {type(processor)} failed with {ex}")
else:
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
@@ -848,7 +852,7 @@ def _get_multi_modal_submodels_and_export_configs(

if model_type == "internvl-chat" and preprocessors is not None:
model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("<IMG_CONTEXT>")

if model_type == "phi3-v":
model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist()
model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist()
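The `chat_template` workaround in `convert.py` above can be read as the standalone sketch below (an illustration, not a helper defined in the repo): with transformers >= 4.45 the phi3-v processor has no `chat_template` attribute and `save_pretrained` fails, so the attribute is copied from the tokenizer before saving. The assumption that `preprocessors[0]` is the tokenizer and `preprocessors[1]` is the processor mirrors the diff.

```python
# Illustrative sketch; mirrors the workaround above rather than defining a real library helper.
def save_preprocessors(preprocessors, output_dir):
    if len(preprocessors) > 1 and not hasattr(preprocessors[1], "chat_template"):
        # borrow the chat template from the tokenizer so the processor can be serialized
        preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None)
    for processor in preprocessors:
        try:
            processor.save_pretrained(output_dir)
        except Exception as ex:
            print(f"Saving {type(processor)} failed with {ex}")
```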
25 changes: 13 additions & 12 deletions optimum/exporters/openvino/model_configs.py
@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from packaging import version
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -1749,7 +1749,8 @@ class Phi3VisionConfigBehavior(str, enum.Enum):


class DummyPhi3VisionProjectionInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("input",)

def __init__(
self,
task: str,
@@ -1762,10 +1763,10 @@ def __init__(
):
self.batch_size = batch_size
self._embed_layer_realization = normalized_config.config.embd_layer["embedding_cls"]
self.image_dim_out = normalized_config.config.img_processor["image_dim_out"]
self.height = height
self.width = width

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
h = self.height // 336
w = self.width // 336
@@ -1777,7 +1778,6 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)



@register_in_tasks_manager("phi3-v", *["image-text-to-text"], library_name="transformers")
class Phi3VisionOpenVINOConfig(OnnxConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Phi3VisionConfigBehavior]
@@ -1804,14 +1804,15 @@ def __init__(
self._behavior = behavior
self._orig_config = config
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "img_processor"):
self._config = AutoConfig.from_pretrained(
config.img_processor["model_name"], trust_remote_code=True
).vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
if self._behavior == Phi3VisionConfigBehavior.VISION_PROJECTION and hasattr(config, "img_processor"):
self._config = config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyPhi3VisionProjectionInputGenerator,)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1823,7 +1824,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior in [Phi3VisionConfigBehavior.VISION_EMBEDDINGS, Phi3VisionConfigBehavior.VISION_PROJECTION]:
return {"last_hidden_state": {0: "batch_size", 1: "height_width_projection"}}
return {"last_hidden_state": {0: "batch_size", 1: "height_width_projection"}}
return {}

def with_behavior(
@@ -1928,8 +1929,8 @@ def get_model_for_behavior(self, model, behavior: Union[str, Phi3VisionConfigBeh
vision_embeddings = model.model.vision_embed_tokens
vision_embeddings.config = model.config
return vision_embeddings

if behavior == Phi3VisionConfigBehavior.VISION_PROJECTION:
projection = model.model.vision_embed_tokens.img_projection
projection.config = model.config
return projection
@@ -1945,4 +1946,4 @@ def patch_model_for_export(
model_kwargs = model_kwargs or {}
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)
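For readers unfamiliar with the behavior-based export split, a small sketch of how `get_model_for_behavior` above carves phi3-v into exportable pieces. The submodule paths are taken from the diff; the helper itself and the behavior strings are illustrative stand-ins, not the library's API.

```python
# Illustrative sketch; submodule paths mirror the diff above, the helper is hypothetical.
def pick_submodel(model, behavior: str):
    # behavior strings stand in for Phi3VisionConfigBehavior members
    if behavior == "vision_embeddings":
        submodel = model.model.vision_embed_tokens                   # image encoder producing patch features
    elif behavior == "vision_projection":
        submodel = model.model.vision_embed_tokens.img_projection    # MLP projecting image features to text space
    else:
        raise ValueError("language model / text embeddings are handled by other export configs")
    submodel.config = model.config  # each exported piece carries the top-level model config
    return submodel
```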
5 changes: 3 additions & 2 deletions optimum/exporters/openvino/model_patcher.py
@@ -1362,7 +1362,7 @@ def phi3_442_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -2769,6 +2769,7 @@ def __exit__(self, exc_type, exc_value, traceback):
def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor):
return self.get_img_features(pixel_values)


class Phi3VisionImageEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
@@ -2782,4 +2783,4 @@ def __init__(

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
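`Phi3VisionImageEmbeddingsPatcher` above follows a swap-and-restore pattern: stash the original `forward`, replace it with the simplified `phi3_vision_embeddings_forward` for the duration of the export, and put the original back on exit. A generic, self-contained sketch of that pattern (an illustration, not the `ModelPatcher` API):

```python
# Illustrative sketch of the patch/restore pattern used above; this context manager is not part of the commit.
import types


class ForwardSwap:
    def __init__(self, module, new_forward):
        self._module = module
        self._new_forward = new_forward

    def __enter__(self):
        # stash the original forward on the module, then replace it for the duration of the export
        self._module.__orig_forward = self._module.forward
        self._module.forward = types.MethodType(self._new_forward, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original forward, as the patcher above does
        self._module.forward = self._module.__orig_forward
```

Used as `with ForwardSwap(vision_module, lambda self, pixel_values: self.get_img_features(pixel_values)): ...`, the module regains its original `forward` once the block exits.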
57 changes: 31 additions & 26 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -200,9 +200,7 @@ def forward(self, img_features):
return self.request(img_features)[0]


MODEL_PARTS_CLS_MAPPING = {"vision_projection": OVVisionProjection}


class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
@@ -522,7 +520,7 @@ def _from_transformers(
ov_config=ov_config,
stateful=stateful,
)
config = AutoConfig.from_pretrained(save_dir_path)
config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code)
return cls._from_pretrained(
model_id=save_dir_path,
config=config,
@@ -1148,13 +1146,26 @@ def __init__(
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.sub_GN = torch.tensor(self.config.sub_GN)
self.glb_GN = torch.tensor(self.config.glb_GN)

def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
num_images, num_crops, c, h, w = pixel_values.shape
img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(
num_images, num_crops, -1, self.config.img_processor["image_dim_out"]
)
image_features_proj = self.hd_feature_transform(img_features, image_sizes)
return image_features_proj

@@ -1181,9 +1192,7 @@ def hd_feature_transform(self, image_features, image_sizes):
# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1 : 1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)

# [sub features, separator, global features]
@@ -1194,9 +1203,7 @@ def hd_feature_transform(self, image_features, image_sizes):
global_image_features_hd_newline[i],
]
)
image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0]

return image_features_proj

@@ -1214,13 +1221,9 @@ def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
)

return image_features_hd
@@ -1233,13 +1236,13 @@ def add_image_newline(self, image_features_hd):
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape(
num_images, -1, hid_dim
)
return image_features_hd_newline

def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs
):
MAX_INPUT_ID = int(1e9)
input_shape = input_ids.size()
@@ -1251,16 +1254,18 @@ def get_multimodal_embeddings(
input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
if has_image:
vision_embeds = self.get_vision_embeddings(
pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
)
image_features_proj = torch.from_numpy(vision_embeds)
inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)

return inputs_embeds, attention_mask, position_ids


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"internvl_chat": _OvInternVLForCausalLM,
"phi3_v": _OVPhi3VisionForCausalLM
"phi3_v": _OVPhi3VisionForCausalLM,
}
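The reshape chain in `reshape_hd_patches_2x2merge` is easiest to follow with concrete shapes. Below is a standalone shape walk-through assuming the usual phi3-v geometry of 24x24 = 576 CLIP patches per crop and hidden size 1024; those numbers and the leading `reshape(N, H, H, C)` step are assumptions for illustration, not read from the diff.

```python
# Illustrative shape check of the 2x2 patch merge; not part of the commit.
import torch

num_images, h_crop, w_crop, H, C = 1, 2, 3, 24, 1024   # assumed phi3-v geometry
N = num_images * h_crop * w_crop                        # total number of crops
image_features = torch.randn(N, H * H, C)               # (6, 576, 1024) patch features per crop

merged = (
    image_features.reshape(N, H, H, C)                  # (6, 24, 24, 1024)
    .reshape(N, H // 2, 2, H // 2, 2, C)                # (6, 12, 2, 12, 2, 1024)
    .permute(0, 1, 3, 2, 4, 5)                          # (6, 12, 12, 2, 2, 1024)
    .reshape(N, -1, 4 * C)                              # (6, 144, 4096): each 2x2 patch block concatenated
    .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)
    .permute(0, 1, 3, 2, 4, 5)                          # interleave the crop grid with the patch grid
    .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)
)
assert merged.shape == (1, 24, 36, 4096)                # (n_img, h_crop*12, w_crop*12, 4096)
```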
39 changes: 28 additions & 11 deletions tests/openvino/test_modeling.py
@@ -50,6 +50,7 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
Pix2StructForConditionalGeneration,
@@ -1880,9 +1881,11 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
]

if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next"]
SUPPORTED_ARCHITECTURES += ["llava_next", "phi3_v"]
TASK = "image-text-to-text"

REMOTE_CODE_MODELS = ["phi3_v"]

IMAGE = Image.open(
requests.get(
"http://images.cocodataset.org/val2017/000000039769.jpg",
@@ -1899,19 +1902,27 @@ def get_transformer_model_class(self, model_arch):
from transformers import LlavaNextForConditionalGeneration

return LlavaNextForConditionalGeneration
return None
return AutoModelForCausalLM

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
prompt = "<image>\n What is shown in this image?"
prompt = (
"<image>\n What is shown in this image?"
if not "phi3_v" in model_arch
else "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"
)
model_id = MODEL_NAMES[model_arch]
processor = get_preprocessor(model_id)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
ov_model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
self.assertIsInstance(ov_model, MODEL_TYPE_TO_CLS_MAPPING[ov_model.config.model_type])
self.assertIsInstance(ov_model.vision_embeddings, OVVisionEmbedding)
self.assertIsInstance(ov_model.language_model, OVModelWithEmbedForCausalLM)
@@ -1950,20 +1961,26 @@ def test_compare_to_transformers(self, model_arch):
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_generate_utils(self, model_arch):
model_id = MODEL_NAMES[model_arch]
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
preprocessor = get_preprocessor(model_id)
question = "<image>\nDescribe image"
model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
preprocessor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
question = (
"<image>\nDescribe image"
if not "phi3_v" in model_arch
else "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"
)
inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")

# General case
outputs = model.generate(**inputs, max_new_tokens=10)
outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True)
outputs = preprocessor.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
self.assertIsInstance(outputs[0], str)

question = "Hi, how are you?"
inputs = preprocessor(images=None, text=question, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True)
outputs = preprocessor.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
self.assertIsInstance(outputs[0], str)
del model

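End to end, the scenario these new tests exercise looks roughly like the sketch below: export the tiny phi3_v test checkpoint (registered in `utils_tests.py`) to OpenVINO with remote code enabled, then generate from an image prompt and decode only the newly produced tokens. This is a usage illustration, not a snippet from the test suite itself.

```python
# Illustrative usage sketch; mirrors what test_generate_utils does for phi3_v.
import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM

model_id = "katuni4ka/tiny-random-phi3-vision"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"
inputs = processor(images=image, text=prompt, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=10)
# decode only the generated continuation, as the updated test does
print(processor.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])
```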
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -104,6 +104,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
"phi3_v": "katuni4ka/tiny-random-phi3-vision",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "fxmarty/tiny-dummy-qwen2",
