Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sharing memory flags for OpenVINO #429

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,13 @@ def forward(
inputs["attention_mask"] = np.array(attention_mask)

# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
results = self.request.infer(inputs, share_inputs=True, share_outputs=True)

logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
logits = torch.from_numpy(results["logits"]).to(self.device)

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
past_key_values = tuple(results[key] for key in self.key_value_output_names)
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
Expand Down
8 changes: 4 additions & 4 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def __call__(self, input_ids: np.ndarray):
inputs = {
"input_ids": input_ids,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True, share_outputs=True)
return list(outputs.values())


Expand Down Expand Up @@ -587,7 +587,7 @@ def __call__(
if time_ids is not None:
inputs["time_ids"] = time_ids

outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True, share_outputs=True)
return list(outputs.values())


Expand All @@ -603,7 +603,7 @@ def __call__(self, latent_sample: np.ndarray):
inputs = {
"latent_sample": latent_sample,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True, share_outputs=True)
return list(outputs.values())

def _compile(self):
Expand All @@ -624,7 +624,7 @@ def __call__(self, sample: np.ndarray):
inputs = {
"sample": sample,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True, share_outputs=True)
return list(outputs.values())

def _compile(self):
Expand Down
9 changes: 4 additions & 5 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to(
last_hidden_state = torch.from_numpy(self.request(inputs, share_inputs=True)["last_hidden_state"]).to(
self.device
)

Expand Down Expand Up @@ -413,13 +413,12 @@ def forward(
if "encoder_hidden_states" in self.input_names and encoder_hidden_states is not None:
inputs["encoder_hidden_states"] = encoder_hidden_states
# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
results = self.request.infer(inputs, share_inputs=True, share_outputs=True)
logits = torch.from_numpy(results["logits"]).to(self.device)

# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
out_past_key_values = tuple(results[key] for key in self.key_value_output_names)

# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
# * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
Expand Down
Loading