update legacy processing path
eaidova committed Nov 7, 2024
1 parent 91095b9 commit bb90a76
Showing 2 changed files with 75 additions and 86 deletions.
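For orientation before the diff: the change drops the cached `self._legacy_processing` flag and instead decides on every forward call whether the inputs came from a legacy processor, by comparing the number of image placeholder tokens in `input_ids` with `config.image_seq_length`. The sketch below is a minimal, standalone illustration of that decision; the function name and token ids are hypothetical and not part of the repository.

```python
import torch

def needs_legacy_processing(input_ids: torch.Tensor, image_token_index: int, image_seq_length: int) -> bool:
    """Return True when the prompt was built by a legacy processor.

    New processors (transformers >= 4.45 with patch_size and
    vision_feature_select_strategy set) expand every "<image>" placeholder
    into image_seq_length tokens, so a prompt containing fewer image tokens
    than that still needs the legacy merging path inside the model.
    """
    image_token_count = (input_ids == image_token_index).sum(dim=1).max()
    return bool(image_token_count < image_seq_length)

# A single un-expanded "<image>" token (illustrative id 32000) -> legacy path.
input_ids = torch.tensor([[1, 32000, 1724, 338, 4318, 297, 445, 1967, 29973]])
print(needs_legacy_processing(input_ids, image_token_index=32000, image_seq_length=576))  # True
```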
67 changes: 15 additions & 52 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -723,7 +723,6 @@ def __init__(
**kwargs,
)
self._support_new_processing = hasattr(self.config, "image_seq_length")
self._legacy_processing = not self._support_new_processing

def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
@@ -753,13 +752,11 @@ def merge_vision_text_embeddings(
input_ids,
attention_mask,
position_ids=None,
legacy_processing=None,
legacy_processing=False,
**kwargs,
):
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if legacy_processing is None:
legacy_processing = self._legacy_processing

if legacy_processing:
pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -792,15 +789,6 @@ def merge_vision_text_embeddings(
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
@@ -811,15 +799,15 @@ def merge_vision_text_embeddings(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model a/pre-releasesre wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

@@ -839,56 +827,35 @@ def merge_vision_text_embeddings(
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
):
inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)

if pixel_values is not None and self._support_new_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
self._legacy_processing = legacy_processing

else:
legacy_processing = True
inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
input_ids, pixel_values, attention_mask, position_ids, legacy_processing=self._legacy_processing, **kwargs
input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
)

if self._legacy_processing and pixel_values is not None and past_key_values is not None:
if legacy_processing and pixel_values is not None and past_key_values is not None:
attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)

return inputs_embeds, attention_mask, position_ids

def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
if not self.language_model.stateful:
first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
else:
first_layer_past_key_value = torch.from_numpy(
self.language_model.request.query_state()[0].state.data[:, :, :, 0]
)

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
past_length = self.language_model._get_past_length(past_key_values)

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
position_ids = torch.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
return attention_mask, position_ids


@@ -969,7 +936,8 @@ def get_multimodal_embeddings(
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
self._legacy_processing = legacy_processing
else:
legacy_processing = True

if pixel_values is not None and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
@@ -1007,16 +975,11 @@ def get_multimodal_embeddings(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
legacy_processing=self._legacy_processing,
legacy_processing=legacy_processing,
**kwargs,
)

if (
self._legacy_processing
and pixel_values is not None
and past_key_values is not None
and input_ids.shape[1] == 1
):
if legacy_processing and pixel_values is not None and past_key_values is not None and input_ids.shape[1] == 1:
attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)

return inputs_embeds, attention_mask, position_ids
@@ -1029,7 +992,7 @@ def merge_vision_text_embeddings(
input_ids,
attention_mask,
position_ids=None,
legacy_processing=None,
legacy_processing=False,
**kwargs,
):
image_token_index = self.config.image_token_index
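Before moving on to the test changes: the `_filter_unattended_tokens` edit above also replaces the single scalar position (`torch.sum(attention_mask, dim=1) - 1`) with a full per-token position vector built from a cumulative sum, with masked slots forced to 1. A small, hedged illustration of that computation follows; the mask values are made up, and `dim=` is used in place of the `axis=` alias seen in the diff.

```python
import torch

# Attention mask for a batch of two sequences, the second left-padded.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# Positions count attended tokens only; masked slots are overwritten with 1,
# mirroring the updated _filter_unattended_tokens behaviour.
position_ids = torch.cumsum(attention_mask, dim=1) - 1
position_ids[attention_mask == 0] = 1
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
```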
94 changes: 60 additions & 34 deletions tests/openvino/test_modeling.py
@@ -1983,45 +1983,71 @@ def test_compare_to_transformers(self, model_arch):
torch.equal(ov_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
)

# the previous run used legacy processing; run once more with feature concatenation at the preprocessing level
if (
model_arch in ["llava", "llava-next"]
and is_transformers_version(">=", "4.45")
and (processor.patch_size is None or processor.vision_feature_select_strategy is None)
):
processor.patch_size = ov_model.config.vision_config.patch_size
processor.vision_feature_select_strategy = ov_model.config.vision_feature_select_strategy
if model_arch == "llava":
# the llava test model does not have image_seq_length specified and it differs from the default
transformers_model.config.image_seq_length = 225
ov_model.config.image_seq_length = 225
self.assertTrue(processor.patch_size is not None)
self.assertTrue(processor.vision_feature_select_strategy is not None)
inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
self.assertTrue(
(inputs.input_ids == ov_model.config.image_token_index).sum(1).max()
>= ov_model.config.image_seq_length
)
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
set_seed(SEED)
ov_outputs = ov_model(**inputs)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
set_seed(SEED)
ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
set_seed(SEED)
transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
self.assertTrue(
torch.equal(ov_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
)
del transformers_model
del ov_model

gc.collect()

@unittest.skipIf(
is_transformers_version("<", "4.45.0"), reason="New preprocessing available only in transformers >= 4.45"
)
@parameterized.expand(["llava", "llava_next"])
def test_llava_with_new_preprocessing(self, model_arch):
prompt = "<image>\n What is shown in this image?"
model_id = MODEL_NAMES[model_arch]
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
processor = AutoProcessor.from_pretrained(
model_id,
patch_size=config.vision_config.patch_size,
vision_feature_select_strategy=config.vision_feature_select_strategy,
trust_remote_code=model_arch in self.REMOTE_CODE_MODELS,
)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id)
ov_model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
self.assertTrue(ov_model._support_new_processing)
if model_arch == "llava":
# the llava test model does not have image_seq_length specified and it differs from the default
transformers_model.config.image_seq_length = 225
ov_model.config.image_seq_length = 225
self.assertTrue(processor.patch_size is not None)
self.assertTrue(processor.vision_feature_select_strategy is not None)
inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
self.assertTrue(
(inputs.input_ids == ov_model.config.image_token_index).sum(1).max() >= ov_model.config.image_seq_length
)
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
set_seed(SEED)
ov_outputs = ov_model(**inputs)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
ov_model.generation_config.eos_token_id = None
transformers_model.generation_config.eos_token_id = None
ov_model.config.eos_token_id = None
transformers_model.config.eos_token_id = None
gen_config = GenerationConfig(
max_new_tokens=30,
min_new_tokens=30,
num_beams=3,
do_sample=False,
eos_token_id=None,
)
set_seed(SEED)
ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
self.assertTrue(
torch.equal(ov_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
)

del ov_model
del transformers_model
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_generate_utils(self, model_arch):
model_id = MODEL_NAMES[model_arch]
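As context for the new test above: with transformers >= 4.45 the processor itself expands each `<image>` placeholder into `image_seq_length` tokens once `patch_size` and `vision_feature_select_strategy` are set, so the exported model skips the legacy merging path. A hedged end-to-end usage sketch follows; the checkpoint name and image URL are illustrative, not the test fixtures.

```python
import requests
from PIL import Image
from transformers import AutoConfig, AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint, not the test model
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(
    model_id,
    patch_size=config.vision_config.patch_size,
    vision_feature_select_strategy=config.vision_feature_select_strategy,
)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="<image>\nWhat is shown in this image?", return_tensors="pt")

# With the new preprocessing the prompt already holds image_seq_length image tokens,
# so generation runs without the legacy vision/text merging inside the model.
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```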
