# Stateful models by default (#486)

This PR threads a new `stateful` flag through the OpenVINO export path and adds a `stateful` module that folds the kv-cache inputs and outputs of a decoder model into internal model state.
**`optimum/exporters/openvino/__init__.py`** (path inferred from the relative imports and the re-exports below):

```diff
@@ -1,5 +1,6 @@
 from .__main__ import main_export
 from .convert import export, export_models, export_pytorch_via_onnx
+from .stateful import patch_stateful


 __all__ = ["main_export", "export", "export_models"]
```
**`optimum/exporters/openvino/convert.py`** (module inferred from the `from .convert import export, ...` line above):

```diff
@@ -30,6 +30,7 @@
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
 from optimum.utils import is_diffusers_available
+from .stateful import patch_stateful

 from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
 from .utils import (
@@ -77,6 +78,7 @@ def export(
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
     int8: bool = False,
+    stateful: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -120,6 +122,7 @@ def export(
             model_kwargs=model_kwargs,
             fp16=fp16,
             int8=int8,
+            stateful=stateful,
         )

     elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -232,6 +235,7 @@ def export_pytorch(
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
     int8: bool = False,
+    stateful: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -364,6 +368,13 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_node().set_partial_shape(static_shape)
             inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
+
+        if stateful:
+            # Patch the model according to the stateful parameters
+            model.key_value_input_names = [name for name in input_names if name.startswith('past_key_values.')]
+            model.key_value_output_names = [name for name in output_names if name.startswith('present.')]
+            patch_stateful(model, ov_model)
+
         _save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8)
         clear_class_registry()
         del model
@@ -383,6 +394,7 @@ def export_models(
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
     int8: bool = False,
+    stateful: bool = False,
 ) -> Tuple[List[List[str]], List[List[str]]]:
     """
     Export the models to OpenVINO IR format
@@ -429,6 +441,7 @@ def export_models(
                 model_kwargs=model_kwargs,
                 fp16=fp16,
                 int8=int8,
+                stateful=stateful,
             )
         )
```

**Review comment** (on the `if stateful:` hunk): could this be done in …
**Author reply:** rewritten in #493
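For orientation, a minimal sketch of driving the new flag through `export`. Only the `model_kwargs`, `fp16`, `int8`, and `stateful` keyword arguments are confirmed by this diff; the leading arguments (`model`, `config`, `output`) are assumptions based on the upstream optimum exporter signature:

```python
from pathlib import Path
from optimum.exporters.openvino import export

# hypothetical invocation: the leading argument names are assumed; the
# fp16/int8/stateful keyword arguments are the ones visible in this diff
input_names, output_names = export(
    model=model,             # a loaded transformers PreTrainedModel
    config=onnx_config,      # the matching export config for the model
    output=Path("openvino_model.xml"),
    stateful=True,           # new: fold the kv-cache into internal model state
)
```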
**`optimum/exporters/openvino/stateful.py`** (new file, +137 lines; module name taken from the `from .stateful import patch_stateful` import above):

```python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import openvino as ov
from openvino.runtime import opset13
import numpy as np


def model_has_name(ov_model: ov.Model, name: str):
    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], list())
```

**Review comment:** Suggested change: …
**Author reply:** renamed in #493
```python
def model_has_input(ov_model: ov.Model, name: str):
    return name in sum([list(t.get_names()) for t in ov_model.inputs], list())


def model_has_cache_reorder(ov_model):
    return model_has_input(ov_model, 'beam_idx')
```

**Review comment:** The name should be …
**Follow-up (same reviewer):** Never mind. It's reasonable to state that …
**Author reply:** I wasn't super creative when choosing this name, and just borrowed it from …
```python
def model_has_state(ov_model):
    # TODO: Provide a better way based on the variables' availability, but the OV Python API doesn't expose the required methods
    return len(ov_model.get_sinks()) > 0
```

**Review comment:** Suggested change: … Would add an underscore to most of the added functions to hint that they are not intended for external use (to give us more freedom in the future).
**Author reply:** We want to use this method externally, which is why a name without an underscore is preferable for us. The TODO inside means that the implementation may change in the future, but as part of the API this function is very useful.
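A small hypothetical usage of this check, e.g. as a guard against patching a model twice (illustrative only; `patch_stateful` is defined later in this file):

```python
# hypothetical guard: only apply the stateful patch to models without state
if not model_has_state(ov_model):
    patch_stateful(model, ov_model)
```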
```python
def fuse_cache_reorder(ov_model: ov.Model, not_kv_inputs, key_value_input_names, gather_dim: int):
    """Adds a new beam_idx parameter and a Gather op for each kv-cache input in a given model.

    Should be run before make_stateful. Implements optimum's _reorder_cache
    inside the model at the beginning of each iteration.
    Gather works along the given gather_dim dimension, which may vary from model to model.
    KV-cache inputs are identified by the names in key_value_input_names.
    Appends the new beam_idx parameter to not_kv_inputs.
    """
    assert not model_has_name(ov_model, 'beam_idx')
    input_batch = ov_model.input('input_ids').get_partial_shape()[0]
    beam_idx = opset13.parameter(name='beam_idx', dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({'beam_idx'})  # why is a list not accepted here?
    ov_model.add_parameters([beam_idx])
    not_kv_inputs.append(ov_model.inputs[-1])
    # Go over all cache parameters and fuse _reorder_cache with the indices provided by the new beam_idx parameter
    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()
```
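For context, this is roughly the host-side reordering that the fused Gather replaces. A sketch of the transformers-style `_reorder_cache` for GPT-2-like models (reference code, not part of this PR):

```python
import torch

def _reorder_cache_reference(past_key_values, beam_idx: torch.Tensor):
    """During beam search, select the kv-cache rows of the surviving beams.
    fuse_cache_reorder moves exactly this gather into the OpenVINO graph,
    driven by the new beam_idx input, so the cache never leaves the model."""
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )
```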
```python
def build_state_initializer(ov_model: ov.Model, batch_dim):
    """Build an initialization ShapeOf expression for all ReadValue ops"""
    input_ids = ov_model.input('input_ids')
    batch = opset13.gather(opset13.shape_of(input_ids, output_type='i64'), opset13.constant([0]), opset13.constant(0))
    for op in ov_model.get_ops():
        if op.get_type_name() == 'ReadValue':
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            dims = [opset13.constant(np.array([dim], dtype=np.int64)) if type(dim) is int else dim for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()
```
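An illustrative numpy analogue of what the generated subgraph computes: zeros whose batch dimension is read from `input_ids` at runtime, with the remaining dimensions fixed at their minimum lengths. The kv-state layout below is an assumption for a GPT-2-style model:

```python
import numpy as np

def init_state_reference(input_ids: np.ndarray, state_dims, batch_dim: int) -> np.ndarray:
    # zeros whose batch dimension comes from input_ids at runtime, mirroring
    # the ShapeOf -> Gather -> Concat -> Broadcast chain built above
    shape = list(state_dims)
    shape[batch_dim] = input_ids.shape[0]
    return np.zeros(shape, dtype=np.float32)

# e.g. an empty kv state with layout [batch, num_heads, seq_len=0, head_size]
state = init_state_reference(np.zeros((2, 5), dtype=np.int64), [0, 12, 0, 64], batch_dim=0)
assert state.shape == (2, 12, 0, 64)
```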
```python
def make_stateful(
    ov_model: ov.Model,
    not_kv_inputs,
    key_value_input_names,
    key_value_output_names,
    batch_dim,
    num_attention_heads,
    num_beams_and_batch=None,
):
    """Hides kv-cache inputs and outputs inside the model as variables."""
    from openvino._offline_transformations import apply_make_stateful_transformation

    input_output_map = {}
    # TODO: Can we derive the dimensions from the model topology?

    if num_beams_and_batch is not None:
        # Set the batch size for input_ids and attention_mask to avoid a dynamic dimension
        # being propagated from the end of the model back to ReadValue
        for input in not_kv_inputs:
            shape = input.get_partial_shape()
            if shape.rank.get_length() <= 2:  # == 1 for beam_idx
                shape[0] = num_beams_and_batch
                input.get_node().set_partial_shape(shape)
            else:
                print(f'[ WARNING ] Rank of {input.get_any_name()} input of the model is not 2, batch size is not set')

    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
        if num_beams_and_batch is not None:
            input = ov_model.input(kv_name_pair[0])
            shape = input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            input.get_node().set_partial_shape(shape)

    if num_beams_and_batch is not None:
        # Re-validate the model if shapes were altered above
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if num_beams_and_batch is None:
        build_state_initializer(ov_model, batch_dim)
```

**Review comment:** could you add type hinting?
**Author reply:** added in #493
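The heavy lifting here is OpenVINO's offline transformation. A minimal standalone illustration of the same call, assuming a model whose kv-cache ports follow the `past_key_values.*`/`present.*` naming used elsewhere in this PR:

```python
from openvino._offline_transformations import apply_make_stateful_transformation

# pair each kv-cache input with its matching output; after the transformation
# these ports disappear from the model signature and become internal
# ReadValue/Assign state variables
input_output_map = {
    "past_key_values.0.key": "present.0.key",
    "past_key_values.0.value": "present.0.value",
}
apply_make_stateful_transformation(ov_model, input_output_map)
```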
```python
def patch_stateful(model, ov_model):
    not_kv_inputs = [
        input
        for input in ov_model.inputs
        if not any(name in model.key_value_input_names for name in input.get_names())
    ]

    # By default the batch is the 0-th dimension, but chatglm uses the 1-st dimension as batch
    # TODO: Deduce it from the model via ordinal reshape (?) and topology
    batch_dim = 1 if model.config.model_type == 'chatglm' else 0

    fuse_cache_reorder(ov_model, not_kv_inputs, model.key_value_input_names, batch_dim)

    num_attention_heads = model.normalized_config.num_attention_heads if model.config.model_type == 'bloom' else 1

    make_stateful(
        ov_model,
        not_kv_inputs,
        model.key_value_input_names,
        model.key_value_output_names,
        batch_dim,
        num_attention_heads,
        None,
    )
```
**Review comment:** Should we also check the task before applying the patch? For this transformation the model should have `with_past` in its task type, and if I understand correctly the current changes target only `text-generation-with-past`. There are also seq2seq models with a decoder-with-past where, as @AlexKoff88 already mentioned in a comment, the transformation should be applied to only one of the three models.
**Author reply:** As implemented now, it is the user's responsibility to set `stateful=True` only for the right models. Enabling it by default would require more adjustments in the code than I can provide in this PR, and part of the models would not be supported. I am simply not familiar enough with the code base to make the changes needed to enable it by default; we need someone who can do that.
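To close the loop, a hedged end-to-end sketch of running an IR exported with `stateful=True`: the kv-cache lives in internal state, so only `input_ids`, `attention_mask`, and the new `beam_idx` input are fed, and the state is reset between sequences. The input names match this PR; the shapes and token values are placeholders for a GPT-2-style decoder:

```python
import numpy as np
from openvino.runtime import Core

core = Core()
compiled = core.compile_model("openvino_model.xml", "CPU")  # stateful IR from the export above
request = compiled.create_infer_request()

request.reset_state()  # clear the internal kv-cache before starting a new sequence
outputs = request.infer({
    "input_ids": np.array([[101, 2023, 2003]], dtype=np.int64),
    "attention_mask": np.ones((1, 3), dtype=np.int64),
    "beam_idx": np.array([0], dtype=np.int32),  # identity reorder, i.e. greedy decoding
})
# subsequent steps feed only the newly generated token; past keys/values stay inside the model
```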