
Commit 7862263

fix images processing

eaidova committed Oct 29, 2024
1 parent 8623b97
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -193,7 +193,16 @@ def forward(self, pixel_values, **kwargs):
         )
 
 
-MODEL_PARTS_CLS_MAPPING = {}
+class OVVisionProjection(OVModelPart):
+    _model_name = "vision_projection"
+
+    def forward(self, img_features):
+        return self.request(img_features)[0]
+
+
+MODEL_PARTS_CLS_MAPPING = {
+    "vision_projection": OVVisionProjection
+}
 
 
 class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
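
Note (not part of the commit): MODEL_PARTS_CLS_MAPPING is a name-to-class registry, so optional sub-models such as the new vision projection can be attached to the composite model without changing the base class. A minimal, self-contained sketch of that pattern follows; the part class, the attach loop, and all names in it are illustrative assumptions, not optimum-intel internals.

    # Illustrative registry pattern only; not the library's actual code.
    class VisionProjectionPart:
        def __init__(self, compiled_model, parent):
            self.request = compiled_model  # callable compiled sub-model
            self.parent = parent

    PART_REGISTRY = {"vision_projection": VisionProjectionPart}

    def attach_parts(parent, compiled_models: dict):
        # Wrap each available compiled sub-model in its registered part class
        # and expose it as an attribute on the parent model.
        for name, part_cls in PART_REGISTRY.items():
            compiled = compiled_models.get(name)
            setattr(parent, name, part_cls(compiled, parent) if compiled is not None else None)
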
@@ -1141,11 +1150,11 @@ def __init__(
     ):
         super().__init__(language_model, text_embeddings, vision_embeddings, config, device, dynamic_shapes, ov_config, model_save_dir, quantization_config, **kwargs)
         self.sub_GN = torch.tensor(self.config.sub_GN)
-        self.glb_GN = torch.tensor(self.config.sub_GN)
+        self.glb_GN = torch.tensor(self.config.glb_GN)
 
     def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
         num_images, num_crops, c, h, w = pixel_values.shape
-        img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(num_images, num_crops, -1, self.config.config.img_processor['image_dim_out'])
+        img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(num_images, num_crops, -1, self.config.img_processor['image_dim_out'])
         image_features_proj = self.hd_feature_transform(img_features, image_sizes)
         return image_features_proj
 
@@ -1154,6 +1163,7 @@ def hd_feature_transform(self, image_features, image_sizes):
         image_features: (num_images, num_crops+1, 24*24, 1024)
         """
 
+        image_features = torch.from_numpy(image_features)
         global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
         # global feature can be viewed as a special HD case with num_crops 1x1
         global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
@@ -1184,10 +1194,9 @@ def hd_feature_transform(self, image_features, image_sizes):
                     global_image_features_hd_newline[i],
                 ]
             )
-
-        image_features_proj = self.img_projection(
-            torch.cat(all_image_embeddings, dim=0)
-        )
+        image_features_proj = self.vision_projection(
+            torch.cat(all_image_embeddings, dim=0).unsqueeze(0)
+        )[0]
 
         return image_features_proj
 
@@ -1244,9 +1253,7 @@ def get_multimodal_embeddings(
         if has_image:
             vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs)
             image_features_proj = torch.from_numpy(vision_embeds)
-            inputs_embeds.index_put(
-                positions, image_features_proj, accumulate=False
-            )
+            inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)
 
         return inputs_embeds, attention_mask, position_ids
 
@@ -1255,4 +1262,5 @@ def get_multimodal_embeddings(
     "llava": _OVLlavaForCausalLM,
     "llava_next": _OVLlavaNextForCausalLM,
     "internvl_chat": _OvInternVLForCausalLM,
+    "phi3_v": _OVPhi3VisionForCausalLM
 }
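
For context, a minimal usage sketch for the newly registered "phi3_v" model type. This is not part of the commit: the model directory, image file, prompt template, and processor keyword arguments below are assumptions based on common Phi-3-vision conventions.

    # Hedged sketch: paths, prompt format, and processor kwargs are assumptions.
    from PIL import Image
    from transformers import AutoProcessor
    from optimum.intel import OVModelForVisualCausalLM

    model_dir = "ov-phi-3-vision"  # placeholder: a locally exported OpenVINO model
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    model = OVModelForVisualCausalLM.from_pretrained(model_dir, trust_remote_code=True)

    image = Image.open("example.png")  # placeholder image
    prompt = "<|user|>\n<|image_1|>\nDescribe this image.<|end|>\n<|assistant|>\n"

    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    generated = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(generated, skip_special_tokens=True)[0])
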
