extend tests
eaidova committed Nov 7, 2024
1 parent 7c0aa8e commit 91095b9
Showing 2 changed files with 47 additions and 12 deletions.
26 changes: 14 additions & 12 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -722,7 +722,8 @@ def __init__(
             quantization_config=quantization_config,
             **kwargs,
         )
-        self._legacy_processing = not hasattr(self.config, "image_seq_length")
+        self._support_new_processing = hasattr(self.config, "image_seq_length")
+        self._legacy_processing = not self._support_new_processing
 
     def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
         if input_ids is not None and input_ids.shape[1] == 1:
@@ -758,9 +759,7 @@ def merge_vision_text_embeddings(
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if legacy_processing is None:
-            legacy_processing = not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1)) or (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            )
+            legacy_processing = self._legacy_processing
 
         if legacy_processing:
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -840,20 +839,19 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = self._legacy_processing
         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
 
-        if pixel_values is not None and not legacy_processing and past_key_values is None:
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
             legacy_processing = (input_ids == self.config.image_token_index).sum(
                 1
             ).max() < self.config.image_seq_length
             self._legacy_processing = legacy_processing
 
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
-            input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
+            input_ids, pixel_values, attention_mask, position_ids, legacy_processing=self._legacy_processing, **kwargs
         )
 
-        if legacy_processing and pixel_values is not None and past_key_values is not None:
+        if self._legacy_processing and pixel_values is not None and past_key_values is not None:
             attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
 
         return inputs_embeds, attention_mask, position_ids
@@ -966,9 +964,8 @@ def get_multimodal_embeddings(
         from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 
         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
-        legacy_processing = self._legacy_processing
 
-        if pixel_values is not None and not legacy_processing and past_key_values is None:
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
             legacy_processing = (input_ids == self.config.image_token_index).sum(
                 1
             ).max() < self.config.image_seq_length
@@ -1010,11 +1007,16 @@ def get_multimodal_embeddings
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            legacy_processing=legacy_processing,
+            legacy_processing=self._legacy_processing,
             **kwargs,
         )
 
-        if legacy_processing and pixel_values is not None and past_key_values is not None and input_ids.shape[1] == 1:
+        if (
+            self._legacy_processing
+            and pixel_values is not None
+            and past_key_values is not None
+            and input_ids.shape[1] == 1
+        ):
             attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
 
         return inputs_embeds, attention_mask, position_ids
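For context on the change above: with new-style processors (transformers >= 4.45), every image placeholder in the prompt is pre-expanded to image_seq_length copies of image_token_index, so inputs prepared the legacy way contain fewer image tokens than that. Below is a minimal standalone sketch of the detection rule now cached in self._legacy_processing; the token id and sequence length are illustrative assumptions, not the library code itself.

import torch

def needs_legacy_merge(input_ids: torch.Tensor, image_token_index: int, image_seq_length: int) -> bool:
    # Count image tokens per sequence, take the batch maximum; fewer tokens
    # than one fully expanded image means the inputs were built the legacy way.
    return bool((input_ids == image_token_index).sum(dim=1).max() < image_seq_length)

# Illustrative values only: token id 32000, image_seq_length 576 (llava-1.5 style).
legacy_ids = torch.tensor([[1, 32000, 15043, 3186]])  # single unexpanded <image> placeholder
expanded_ids = torch.cat([torch.full((1, 576), 32000), legacy_ids], dim=1)

print(needs_legacy_merge(legacy_ids, 32000, 576))    # True  -> model merges embeddings itself
print(needs_legacy_merge(expanded_ids, 32000, 576))  # False -> processor already expanded the prompt

Caching the decision on the first prefill call (the past_key_values is None branch) lets later decoding steps, where input_ids has length 1 and the rule could not be re-derived, reuse it.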
33 changes: 33 additions & 0 deletions tests/openvino/test_modeling.py
@@ -1984,6 +1984,39 @@ def test_compare_to_transformers(self, model_arch):
             f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
         )
 
+        # the previous run used legacy processing; run once more with image-feature concatenation handled at the preprocessing level
+        if (
+            model_arch in ["llava", "llava-next"]
+            and is_transformers_version(">=", "4.45")
+            and (processor.patch_size is None or processor.vision_feature_select_strategy is None)
+        ):
+            processor.patch_size = ov_model.config.vision_config.patch_size
+            processor.vision_feature_select_strategy = ov_model.config.vision_feature_select_strategy
+            if model_arch == "llava":
+                # the test model for llava does not have image_seq_length specified, and it differs from the default
+                transformers_model.config.image_seq_length = 225
+                ov_model.config.image_seq_length = 225
+            self.assertTrue(processor.patch_size is not None)
+            self.assertTrue(processor.vision_feature_select_strategy is not None)
+            inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
+            self.assertTrue(
+                (inputs.input_ids == ov_model.config.image_token_index).sum(1).max()
+                >= ov_model.config.image_seq_length
+            )
+            set_seed(SEED)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**inputs)
+            set_seed(SEED)
+            ov_outputs = ov_model(**inputs)
+            self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+            set_seed(SEED)
+            ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
+            set_seed(SEED)
+            transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
+            self.assertTrue(
+                torch.equal(ov_outputs, transformers_outputs),
+                f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
+            )
         del transformers_model
         del ov_model
 
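As a hedged standalone sketch of the setup the new test block performs (the checkpoint, image, and prompt below are illustrative assumptions, not the test fixtures): giving the processor patch_size and vision_feature_select_strategy lets it expand <image> placeholders itself, which drives the image-token count up to image_seq_length and exercises the non-legacy path.

from PIL import Image
from transformers import AutoConfig, AutoProcessor

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint, for illustration only
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Backfill the attributes the processor needs to expand <image> placeholders (transformers >= 4.45).
if getattr(processor, "patch_size", None) is None:
    processor.patch_size = config.vision_config.patch_size
if getattr(processor, "vision_feature_select_strategy", None) is None:
    processor.vision_feature_select_strategy = config.vision_feature_select_strategy

image = Image.new("RGB", (336, 336))  # dummy image standing in for the test fixture
prompt = "USER: <image>\nWhat is in the picture? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# With expansion enabled, the prompt carries at least one full image's worth of tokens.
assert (inputs.input_ids == config.image_token_index).sum(1).max() >= config.image_seq_length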
