Skip to content

Commit

Permalink
fix device selection for compilation language model in vlm
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 24, 2024
1 parent 86598a6 commit 9842cfb
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,25 @@ def __init__(

def compile(self):
if self.request is None:
logger.info(f"Compiling the Language model to {self._device} ...")
self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
if self._compile_only:
self.request = self.model.create_infer_request()
else:
logger.info(f"Compiling the Language model to {self._device} ...")
self.request = self._compile_model(
self.model, self._device, self.ov_config, self.model_save_dir
).create_infer_request()
self._compile_text_emb()

def _compile_text_emb(self):
if self.text_emb_request is None:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = core.compile_model(self.text_emb_model, self._device, self.ov_config)
if self._compile_only:
self.text_emb_request = self.text_emb_model
else:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = self._compile_model(
self.text_emb_model, self._device, self.ov_config, self.model_save_dir
)

def clear_requests(self):
if self._compile_only:
Expand Down Expand Up @@ -122,8 +133,8 @@ def prepare_inputs(
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_len:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]

inputs["position_ids"] = position_ids

Expand Down Expand Up @@ -240,7 +251,7 @@ def __init__(
self.lm_model,
self.text_embdings_model,
config=config,
deivce=device,
device=device,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
Expand Down

0 comments on commit 9842cfb

Please sign in to comment.