Stateful models by default #486

Closed
wants to merge 43 commits into from
Changes from 15 commits
d41110c
Allow loading of stateful models (no patching yet)
slyalin Dec 5, 2023
bcb7cac
Stateful models support
slyalin Dec 7, 2023
d4c165b
Fix forward for chatglm
slyalin Dec 7, 2023
a62dc6f
Passing stateful as a dedicated parameter
slyalin Dec 8, 2023
403adb5
Fixed possibly misaligned types in ShapeOf Concat sub-expression
slyalin Dec 8, 2023
6097bfd
Merge remote-tracking branch 'origin/main' into stateful
slyalin Dec 12, 2023
c54a466
Fixed critical typo in infer_request invocation
slyalin Dec 12, 2023
6151bec
Merge remote-tracking branch 'origin/main' into stateful
slyalin Dec 13, 2023
0e5aefc
Apply bettertransfomer when model is converted in stateful mode
slyalin Dec 13, 2023
5087f92
Correct default value handling for stateful flag
slyalin Dec 13, 2023
5ae6d5c
Apply bettertransformer under try-except to avoid crashes when model …
slyalin Dec 13, 2023
ddb182e
Added --stateful option in optimum-cli
slyalin Dec 13, 2023
8e9a7e0
Merge remote-tracking branch 'origin/main' into stateful
slyalin Dec 14, 2023
a51ab27
Raise if too old version of opevino is used ans stateful=True
slyalin Dec 14, 2023
6df798a
Fix openvino version check to be compatible with openvino-nightly
slyalin Dec 14, 2023
fa33784
Merged from recent main branch
slyalin Dec 19, 2023
aa299e6
Fix for bloom family
slyalin Dec 19, 2023
70ff804
Allow loading of stateful models (no patching yet)
slyalin Dec 5, 2023
bf45cc7
Stateful models support
slyalin Dec 7, 2023
fafc040
Fix forward for chatglm
slyalin Dec 7, 2023
e4585fe
Passing stateful as a dedicated parameter
slyalin Dec 8, 2023
a5c48c4
Fixed possibly misaligned types in ShapeOf Concat sub-expression
slyalin Dec 8, 2023
3ce4fe9
Fixed critical typo in infer_request invocation
slyalin Dec 12, 2023
00bd0b3
Apply bettertransfomer when model is converted in stateful mode
slyalin Dec 13, 2023
489799a
Correct default value handling for stateful flag
slyalin Dec 13, 2023
312f24f
Apply bettertransformer under try-except to avoid crashes when model …
slyalin Dec 13, 2023
4bca339
Added --stateful option in optimum-cli
slyalin Dec 13, 2023
a5f8558
Raise if too old version of opevino is used ans stateful=True
slyalin Dec 14, 2023
2763639
Fix openvino version check to be compatible with openvino-nightly
slyalin Dec 14, 2023
f5da152
Fix for bloom family
slyalin Dec 19, 2023
da62d99
Fix general code style and appliy renaming suggestions
eaidova Dec 21, 2023
33f86dc
fix version checking if openvino not in site-packages
eaidova Dec 21, 2023
7d15415
use reset_state if available
eaidova Dec 21, 2023
0dfbb63
remove input patch in bettertransformer apply
eaidova Dec 21, 2023
09bbe7b
add tests
eaidova Dec 21, 2023
e70a50d
add type hints and update doc strings
eaidova Dec 21, 2023
a57c86a
added more tests
eaidova Dec 25, 2023
67deae2
Merge remote-tracking branch 'origin/main' into ea/stateful
slyalin Jan 4, 2024
89ba0cb
Fixed outdated signature of InferRequest wrapper to fix one of the qu…
slyalin Jan 4, 2024
a85960a
Merge remote-tracking branch 'slyalin/stateful' into stateful
slyalin Jan 4, 2024
dcfcc2a
Merge remote-tracking branch 'origin/main' into stateful
slyalin Jan 5, 2024
df0e727
Switch to stateful model by default
slyalin Jan 5, 2024
9992419
Merge remote-tracking branch 'origin/main' into stateful
slyalin Jan 5, 2024
6 changes: 6 additions & 0 deletions optimum/commands/export/openvino.py
@@ -70,6 +70,11 @@ def parse_args_openvino(parser: "ArgumentParser"):
    )
    optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"),
    optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"),
    optional_group.add_argument(
        "--stateful",
        action="store_true",
        help="Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs"
    ),


class OVExportCommand(BaseOptimumCLICommand):
@@ -106,5 +111,6 @@ def run(self):
            pad_token_id=self.args.pad_token_id,
            fp16=self.args.fp16,
            int8=self.args.int8,
            stateful=self.args.stateful,
            # **input_shapes,
        )
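As a rough illustration of how a store_true flag such as --stateful travels from the command line into export keyword arguments, here is a minimal, self-contained sketch. The parser and the export_kwargs helper are hypothetical stand-ins and are not part of optimum-cli.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for the OpenVINO export parser; only three flags are modeled.
    parser = argparse.ArgumentParser(prog="export-openvino-sketch")
    parser.add_argument("--fp16", action="store_true", help="Compress weights to fp16")
    parser.add_argument("--int8", action="store_true", help="Compress weights to int8")
    parser.add_argument(
        "--stateful",
        action="store_true",
        help="Hide kv-cache inputs and outputs inside the model",
    )
    return parser


def export_kwargs(argv):
    # store_true flags default to False when they are absent from argv.
    args = build_parser().parse_args(argv)
    return {"fp16": args.fp16, "int8": args.int8, "stateful": args.stateful}


with_flag = export_kwargs(["--stateful"])
without_flag = export_kwargs([])
```

With action="store_true" the flag is opt-in: omitting it yields False, which matches the state of the diff at this point in the commit history.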
1 change: 1 addition & 0 deletions optimum/exporters/openvino/__init__.py
@@ -1,5 +1,6 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx
from .stateful import patch_stateful, raise_if_openvino_is_too_old


__all__ = ["main_export", "export", "export_models"]
2 changes: 2 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -56,6 +56,7 @@ def main_export(
    custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
    fn_get_submodels: Optional[Callable] = None,
    int8: Optional[bool] = None,
    stateful: Optional[bool] = None,
    **kwargs_shapes,
):
    """
@@ -350,6 +351,7 @@ class StoreAttr(object):
        device=device,
        fp16=fp16,
        int8=int8,
        stateful=stateful,
        model_kwargs=model_kwargs,
    )
41 changes: 41 additions & 0 deletions optimum/exporters/openvino/better_transformer_patch.py
@@ -0,0 +1,41 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import types


def patch_model_with_bettertransformer(model, model_config):
    try:
        model = model.to_bettertransformer()
    except Exception as e:
        print(f'[ WARNING ] Cannot apply model.to_bettertransformer because of the exception:\n{e}')
        return model

    # For BetterTransformer, the sequence length must not be 1 to produce a correct trace,
    # so patch generate_dummy_inputs in the config.

    def pathed_generate_dummy_inputs(self, *args, **kwargs):
        dummy_inputs = self._original_generate_dummy_inputs(*args, **kwargs)
        if 'input_ids' in dummy_inputs and dummy_inputs['input_ids'].shape[1] == 1:
            dummy_inputs['input_ids'] = torch.cat([dummy_inputs['input_ids'], dummy_inputs['input_ids']], dim=-1)
            attention_mask = dummy_inputs['attention_mask']
            dummy_inputs['attention_mask'] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        return dummy_inputs

    model_config._original_generate_dummy_inputs = model_config.generate_dummy_inputs
    model_config.generate_dummy_inputs = types.MethodType(pathed_generate_dummy_inputs, model_config)
Collaborator

Couldn't we use input_shapes instead?

Contributor Author

It can probably be used now. When this code was first written, it was applied externally as a patch for the model. Now it can most likely be implemented directly in the export function.

Collaborator

That will not work, because additional logic modifies sequence_len in OnnxConfigWithPast after the original shapes are specified:
https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/base.py#L663

Collaborator

Starting with optimum 1.14.0, this behaviour is triggered only by the legacy path, which we do not use, so it can be removed entirely.


    return model
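The instance-level patching used in this file can be tried without torch or optimum. The sketch below is only an illustration: DummyConfig is a hypothetical stand-in for an export config, and nested lists stand in for tensors. It keeps the original bound method on the instance, wraps it, and extends a length-1 sequence to length 2.

```python
import types


class DummyConfig:
    """Hypothetical stand-in for an export config; not the optimum class."""

    def generate_dummy_inputs(self):
        # One sequence of length 1 plus a mask covering two past tokens and one new token.
        return {"input_ids": [[7]], "attention_mask": [[1, 1, 1]]}


def patched_generate_dummy_inputs(self, *args, **kwargs):
    dummy = self._original_generate_dummy_inputs(*args, **kwargs)
    if "input_ids" in dummy and len(dummy["input_ids"][0]) == 1:
        # Duplicate the single token and grow the mask by one column of ones.
        dummy["input_ids"] = [row + row for row in dummy["input_ids"]]
        dummy["attention_mask"] = [row + [1] for row in dummy["attention_mask"]]
    return dummy


config = DummyConfig()
# Keep the original bound method so the wrapper can delegate to it.
config._original_generate_dummy_inputs = config.generate_dummy_inputs
config.generate_dummy_inputs = types.MethodType(patched_generate_dummy_inputs, config)

patched = config.generate_dummy_inputs()
untouched = DummyConfig().generate_dummy_inputs()
```

Because the override is bound to one instance, other instances of the same class keep the original behaviour, which is one reason this style of patch is easy to scope.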
30 changes: 30 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -30,6 +30,8 @@
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
from optimum.utils import is_diffusers_available
from .stateful import patch_stateful, raise_if_openvino_is_too_old
from .better_transformer_patch import patch_model_with_bettertransformer

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
from .utils import (
@@ -77,6 +79,7 @@ def export(
    model_kwargs: Optional[Dict[str, Any]] = None,
    fp16: bool = False,
    int8: bool = False,
    stateful: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -120,6 +123,7 @@ def export(
            model_kwargs=model_kwargs,
            fp16=fp16,
            int8=int8,
            stateful=stateful,
        )

    elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -232,6 +236,7 @@ def export_pytorch(
    model_kwargs: Optional[Dict[str, Any]] = None,
    fp16: bool = False,
    int8: bool = False,
    stateful: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -263,6 +268,15 @@ def export_pytorch(
    logger.info(f"Using framework PyTorch: {torch.__version__}")
    output = Path(output)

    if stateful:
        # Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
        # both of them to be applied to demonstrate the best performance.
        # TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.
        model = patch_model_with_bettertransformer(model, config)
        # TODO: Consider unpatching model after export is done at the end of this function.
        # Now it is left as-is because the model is not expected to be used after calling export_pytorch, and
        # this function is one of the _internal_ steps in a bigger model conversion pipeline.

    with torch.no_grad():
        model.config.torchscript = False
        model.config.return_dict = True
@@ -341,6 +355,10 @@ def ts_patched_forward(*args, **kwargs):
            logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
            if patch_model_forward:
                model.forward = orig_forward
            if stateful:
                raise ValueError(
                    'Making stateful models is not supported when exporting to ONNX as an intermediate step. '
                    'Set stateful=False, or provide a model that can be converted to OpenVINO without fallback to ONNX conversion path.')
            return export_pytorch_via_onnx(
                model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8
            )
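The guard added in this hunk can be looked at in isolation. The sketch below assumes two hypothetical callables, direct and fallback, standing in for the direct OpenVINO conversion and the ONNX path. It is not the PR's code, only the control flow: fall back when a stateful model was not requested, fail loudly otherwise.

```python
def convert_with_fallback(model, stateful, direct, fallback):
    """Try the direct path; use the fallback only when stateful was not requested."""
    try:
        return direct(model)
    except Exception as ex:
        if stateful:
            # The fallback cannot produce a stateful model, so surface the failure.
            raise ValueError("stateful=True is not supported on the fallback path") from ex
        return fallback(model)


def failing_direct(model):
    # Hypothetical direct conversion that always fails, to exercise the guard.
    raise RuntimeError("direct conversion failed")


def onnx_like_fallback(model):
    # Hypothetical fallback conversion.
    return f"{model}-via-fallback"


result = convert_with_fallback("m", False, failing_direct, onnx_like_fallback)
```

Chaining with `from ex` keeps the original conversion error visible in the traceback, which helps when diagnosing why the direct path failed.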
@@ -364,6 +382,13 @@ def ts_patched_forward(*args, **kwargs):
            inp_tensor.get_node().set_partial_shape(static_shape)
            inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
        ov_model.validate_nodes_and_infer_types()

        if stateful:
Collaborator

Should we also check the task before applying the patch? For this transformation, the model should have with_past in its task type. If I understand correctly, the current changes target only text-generation-with-past. There are also seq2seq models with a decoder with past, as @AlexKoff88 already mentioned in a comment, where this transformation should be applied to only one of the 3 models.

Contributor Author

As implemented now, it is the user's responsibility to set stateful=True only for the right models. Enabling it by default requires more adjustments in the code, which I cannot provide in this PR. Some of the models will not be supported. I am not familiar enough with the code base to make the changes needed to enable it by default. We need someone who can do that.

            # Patching model according to stateful parameters
            model.key_value_input_names = [name for name in input_names if name.startswith('past_key_values.')]
            model.key_value_output_names = [name for name in output_names if name.startswith('present.')]
Collaborator

Could this be done in patch_stateful directly, without modifying the original model attributes, by checking the model graph inputs or by also providing the input_names / output_names?

Collaborator

rewritten in #493

            patch_stateful(model, ov_model)

        _save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8)
        clear_class_registry()
        del model
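The prefix-based selection of kv-cache tensor names in this hunk can be sketched as a small pure function, which is roughly the direction the review comment above suggests (no attributes set on the model). The helper name split_kv_names and the example names are assumptions for illustration only.

```python
def split_kv_names(input_names, output_names):
    # kv-cache inputs and outputs are recognised purely by their name prefixes.
    kv_inputs = [name for name in input_names if name.startswith("past_key_values.")]
    kv_outputs = [name for name in output_names if name.startswith("present.")]
    return kv_inputs, kv_outputs


# Illustrative names for a single-layer decoder.
example_inputs = ["input_ids", "attention_mask", "past_key_values.0.key", "past_key_values.0.value"]
example_outputs = ["logits", "present.0.key", "present.0.value"]

kv_inputs, kv_outputs = split_kv_names(example_inputs, example_outputs)
```

Order is preserved by the list comprehensions, which matters later because inputs and outputs are paired positionally.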
@@ -383,6 +408,7 @@ def export_models(
    model_kwargs: Optional[Dict[str, Any]] = None,
    fp16: bool = False,
    int8: bool = False,
    stateful: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Export the models to OpenVINO IR format
@@ -406,6 +432,9 @@ def export_models(
    Returns:
        list of input_names and output_names from ONNX configuration
    """
    if stateful:
        # This will be checked anyway after the model conversion, but checking it earlier saves the user time if an unsuitable version is used
        raise_if_openvino_is_too_old()
    outputs = []

    if output_names is not None and len(output_names) != len(models_and_onnx_configs):
@@ -429,6 +458,7 @@ def export_models(
                model_kwargs=model_kwargs,
                fp16=fp16,
                int8=int8,
                stateful=stateful,
            )
        )

146 changes: 146 additions & 0 deletions optimum/exporters/openvino/stateful.py
@@ -0,0 +1,146 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from packaging import version
import openvino as ov
from openvino.runtime import opset13
from optimum.intel.utils.import_utils import is_openvino_version


def model_has_name(ov_model: ov.Model, name: str):
Collaborator

Suggested change
def model_has_name(ov_model: ov.Model, name: str):
def model_has_input_output_name(ov_model: ov.Model, name: str):

Collaborator

renamed in #493

    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], list())


def model_has_input(ov_model: ov.Model, name: str):
    return name in sum([list(t.get_names()) for t in ov_model.inputs], list())
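The two helpers above flatten per-port name sets with sum(..., list()) and then test membership. A minimal sketch of that idiom, next to an equivalent any() form, is shown below. FakePort is a hypothetical stand-in that models only get_names(); it is not an OpenVINO class.

```python
class FakePort:
    """Hypothetical stand-in for a model port; only get_names() is modeled."""

    def __init__(self, *names):
        self._names = set(names)

    def get_names(self):
        return self._names


def has_name_via_sum(ports, name):
    # Same flattening idiom as the diff: concatenate the per-port name lists.
    return name in sum([list(port.get_names()) for port in ports], [])


def has_name_via_any(ports, name):
    # Stops at the first match and builds no intermediate list.
    return any(name in port.get_names() for port in ports)


ports = [FakePort("input_ids"), FakePort("attention_mask", "mask_alias")]
```

Both forms give the same answer; the any() variant is simply a possible alternative when the number of ports is large.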


def model_has_cache_reorder(ov_model):
    return model_has_input(ov_model, 'beam_idx')

The name should be beam_id, or beam_ids if you count the batch, to stay aligned with input_ids.

Never mind. It's reasonable to say that idx is a singular form of the word ids.

Contributor Author

I wasn't super creative when choosing this name and just borrowed it from the argument name of the _reorder_cache method, so it is aligned with the part of the code that this comes from. Also, this idx is not the same as those id or ids: they are indexes for a different thing.



def model_has_state(ov_model):
Collaborator

Suggested change
def model_has_state(ov_model):
def _model_has_state(ov_model):

I would add an underscore to most of the added functions to hint that they are not intended for external use (to give us more freedom in the future).

Collaborator

We want to use this method externally, which is why a name without an underscore is preferable for us.

The TODO inside means that the implementation may change in the future, but as part of the API this function is very useful.

    # TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
    return len(ov_model.get_sinks()) > 0


def fuse_cache_reorder(ov_model: ov.Model, not_kv_inputs, key_value_input_names, gather_dim: int):
    """ Adds a new beam_idx parameter and a Gather op for each kv-cache input in a given model.
        Should be run before make_stateful. Implements optimum's _reorder_cache
        inside the model at the beginning of each iteration.
        Gather works along the given gather_dim dimension, which may vary from model to model.
        KV-cache inputs are identified by the names in key_value_input_names.
        Appends the new beam_idx parameter to not_kv_inputs.
    """

    assert not model_has_name(ov_model, 'beam_idx')
    input_batch = ov_model.input('input_ids').get_partial_shape()[0]
    beam_idx = opset13.parameter(name='beam_idx', dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({'beam_idx'})  # why list is not accepted?
    ov_model.add_parameters([beam_idx])
    not_kv_inputs.append(ov_model.inputs[-1])
    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()
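What the fused Gather computes for each cache tensor can be shown on plain data. The sketch below is an assumption-laden illustration: nested lists stand in for tensors, the gather helper is hypothetical, and only dimensions 0 and 1 are handled. Each output slot i along the chosen dimension receives the slice that beam_idx[i] points to.

```python
def gather(cache, beam_idx, dim):
    """Select slices of a nested-list 'tensor' along dim (0 or 1) using beam_idx."""
    if dim == 0:
        return [cache[i] for i in beam_idx]
    if dim == 1:
        return [[row[i] for i in beam_idx] for row in cache]
    raise ValueError("this sketch only handles dim 0 or 1")


# Batch-first layout: one row of cached values per beam.
batch_first = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]
# Batch-second layout, as described for chatglm in this PR: beams are columns.
batch_second = [["a0", "b0", "c0"], ["a1", "b1", "c1"]]

# Beams 0 and 1 continue from old beam 2; beam 2 continues from old beam 0.
beam_idx = [2, 2, 0]
reordered_first = gather(batch_first, beam_idx, 0)
reordered_second = gather(batch_second, beam_idx, 1)
```

The same beam_idx produces the same logical reordering in both layouts; only the axis along which the selection happens differs, which is why gather_dim is a parameter.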


def build_state_initializer(ov_model: ov.Model, batch_dim):
    """Build initialization ShapeOf Expression for all ReadValue ops"""
    input_ids = ov_model.input('input_ids')
    batch = opset13.gather(opset13.shape_of(input_ids, output_type='i64'), opset13.constant([0]), opset13.constant(0))
    for op in ov_model.get_ops():
        if op.get_type_name() == 'ReadValue':
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            dims = [opset13.constant(np.array([dim], dtype=np.int64)) if type(dim) is int else dim for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()
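The shape arithmetic behind the state initializer can be sketched without OpenVINO. This illustration assumes that a fully dynamic dimension reports a minimum length of 0 and represents it as None; initial_state_shape and element_count are hypothetical helpers. In the real graph the batch value is a runtime sub-expression taken from input_ids, here it is a plain integer.

```python
def initial_state_shape(partial_shape, batch_dim, batch):
    # None marks a dynamic dimension; its minimum length is taken to be 0.
    dims = [0 if dim is None else dim for dim in partial_shape]
    # The batch dimension is overridden with the batch size of input_ids.
    dims[batch_dim] = batch
    return dims


def element_count(shape):
    count = 1
    for dim in shape:
        count *= dim
    return count


# A cache tensor laid out as [batch, heads, sequence, head_size] with dynamic batch and sequence.
shape = initial_state_shape([None, 32, None, 128], batch_dim=0, batch=4)
```

Under these assumptions the initial state is an empty tensor: the sequence dimension is 0, so broadcasting a zero into that shape allocates no cache entries while still fixing the batch and static dimensions.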


def make_stateful(
        ov_model: ov.Model,
        not_kv_inputs,
        key_value_input_names,
        key_value_output_names,
        batch_dim,
        num_attention_heads,
        num_beams_and_batch=None):
Collaborator

Could you add type hints?

Collaborator

added in #493

    """ Hides kv-cache inputs and outputs inside the model as variables.
    """
    from openvino._offline_transformations import apply_make_stateful_transformation

    input_output_map = {}
    # TODO: Can we derive the dimensions from the model topology?

    if num_beams_and_batch is not None:
        # Set batch size for input_ids and attention mask to keep a dynamic dimension from propagating from the end of the model back to ReadValue
        for input in not_kv_inputs:
            shape = input.get_partial_shape()
            if shape.rank.get_length() <= 2:  # == 1 for beam_index
                shape[0] = num_beams_and_batch
                input.get_node().set_partial_shape(shape)
            else:
                print(f'[ WARNING ] Rank of {input.get_any_name()} input of the model is not 2, batch size is not set')

    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
        if num_beams_and_batch is not None:
            input = ov_model.input(kv_name_pair[0])
            shape = input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            input.get_node().set_partial_shape(shape)

    if num_beams_and_batch is not None:
        # Re-validate the model if shapes were altered above
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if num_beams_and_batch is None:
        build_state_initializer(ov_model, batch_dim)
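The pairing of cache inputs with cache outputs, and the fixed batch value used when num_beams_and_batch is given, can be sketched as follows. Both helpers are hypothetical. The length check is an addition of this sketch: the zip in the function above silently truncates on a mismatch.

```python
def build_input_output_map(kv_input_names, kv_output_names):
    # Inputs and outputs are paired positionally, so their order must match.
    if len(kv_input_names) != len(kv_output_names):
        raise ValueError("kv-cache inputs and outputs must pair up one to one")
    return dict(zip(kv_input_names, kv_output_names))


def fixed_kv_batch(num_beams_and_batch, num_attention_heads):
    # The multiplier is 1 for most models; the PR passes the head count for bloom.
    return num_beams_and_batch * num_attention_heads


mapping = build_input_output_map(
    ["past_key_values.0.key", "past_key_values.0.value"],
    ["present.0.key", "present.0.value"],
)
```

Each entry in the mapping names one input and the output that replaces it on the next step, which is the information the stateful transformation needs in order to turn the pair into an internal variable.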


def raise_if_openvino_is_too_old():
    if is_openvino_version("<=", "2023.2"):

That prohibits the use of https://storage.openvinotoolkit.org/repositories/openvino/packages/nightly/2023.3.0-13432-a6ea22ad0e6/l_openvino_toolkit_ubuntu20_2023.3.0.dev20231129_x86_64.tgz. I get ValueError: Could not create or use stateful model when using old version of openvino==2023.3.0-13432-a6ea22ad0e6. Install openvino>=2023.3.0.

Contributor Author

As I remember, the previous version of this condition was different and didn't work with nightly. I changed it and tested it with nightly, and everything worked. I am not sure why it doesn't work now; I probably mixed up the versions when I checked on my side.

Contributor Author

@eaidova reproduced this issue when both openvino and openvino-nightly are installed. @Wovchena please uninstall the packages and install openvino-nightly only.

I suppose we need to fix it to work with different OpenVINO installations.

Collaborator

I have no issues with version detection when using the Python libs from the archive. The problem seems to happen when openvino is installed from multiple sources at the same time (e.g. pypi openvino + pypi openvino-nightly, or pypi openvino + PYTHONPATH to openvino libs) and their versions differ.

Collaborator

In my PR with fixes for this branch,
#493
I changed the logic for version verification. The previous approach checked only packages that exist in the site-packages of the Python environment (installed as wheels) and did not take into account that a user (in a development setup, for example) may install openvino from an archive or build it from source without installing wheels.

Now the detected openvino version should be aligned with the import order.

        raise ValueError(f'Could not create or use stateful model when using old version of openvino=={ov.__version__}. Install openvino>=2023.3.0.')
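The thread above turns on how build strings such as 2023.3.0-13432-a6ea22ad0e6 compare against a minimum version. One possible approach, sketched here with the standard library only, is to compare just the leading major and minor numbers. This is an assumption for illustration and not the check used by this PR or by #493. It also treats a dev build of 2023.3 as new enough, which differs from strict PEP 440 ordering.

```python
import re


def major_minor(version_string):
    # Take the leading "<major>.<minor>" and ignore build or dev suffixes.
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def is_too_old(version_string, minimum=(2023, 3)):
    # Tuples compare element by element, so (2023, 2) < (2023, 3) < (2024, 0).
    return major_minor(version_string) < minimum
```

For example, is_too_old("2023.2.0") is true, while the nightly string quoted in the review comment passes because its leading numbers are 2023 and 3.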


def patch_stateful(model, ov_model):
    raise_if_openvino_is_too_old()
    not_kv_inputs = [input for input in ov_model.inputs if not any(name in model.key_value_input_names for name in input.get_names())]

    # By default, batch is the 0-th but chatglm uses 1-st dimension as batch
    # TODO: Deduce from a model via ordinal reshape (?) and topology
    batch_dim = 1 if model.config.model_type == 'chatglm' else 0

    fuse_cache_reorder(ov_model, not_kv_inputs, model.key_value_input_names, batch_dim)

    num_attention_heads = model.normalized_config.num_attention_heads if model.config.model_type == 'bloom' else 1

    make_stateful(
        ov_model,
        not_kv_inputs,
        model.key_value_input_names,
        model.key_value_output_names,
        batch_dim,
        num_attention_heads,
        None)
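The per-model decisions made in patch_stateful can be isolated into small pure functions, which makes them easy to test. The helper names below are hypothetical, and each model input is represented simply as a set of tensor names rather than an OpenVINO port.

```python
def batch_dim_for(model_type):
    # Batch is dimension 0 by default; chatglm keeps it in dimension 1.
    return 1 if model_type == "chatglm" else 0


def head_multiplier_for(model_type, num_attention_heads):
    # Only bloom gets the head count as multiplier; other models use 1.
    return num_attention_heads if model_type == "bloom" else 1


def non_kv_inputs(inputs, kv_input_names):
    # Keep the inputs none of whose names is a kv-cache input name.
    kv_names = set(kv_input_names)
    return [names for names in inputs if not (names & kv_names)]


inputs = [{"input_ids"}, {"attention_mask"}, {"past_key_values.0.key"}]
remaining = non_kv_inputs(inputs, ["past_key_values.0.key"])
```

Keeping the model-type special cases in one place like this is one way to approach the TODO about deducing the batch dimension from the model itself later on.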