Add OpenVINO stateful model support (#493)
* Allow loading of stateful models (no patching yet)

* Stateful models support

* Fix forward for chatglm

* Passing stateful as a dedicated parameter

* Fixed possibly misaligned types in ShapeOf Concat sub-expression

* Fixed critical typo in infer_request invocation

* Apply bettertransformer when model is converted in stateful mode

* Correct default value handling for stateful flag

* Apply bettertransformer under try-except to avoid crashes when model is not supported

* Added --stateful option in optimum-cli

* Raise if too old a version of openvino is used and stateful=True

* Fix openvino version check to be compatible with openvino-nightly

* Fix for bloom family

* Fix general code style and apply renaming suggestions

* fix version checking if openvino not in site-packages

* use reset_state if available

* remove input patch in bettertransformer apply

* add tests

* add type hints and update doc strings

* added more tests

* Fixed outdated signature of InferRequest wrapper to fix one of the quantizer tests.

* Switch to stateful model by default

* fix test and add beam_idx attribute

* apply review comments

* stateful by default fixes

* less aggressive stateful

* ensure that the task supports stateful

* remove debug print

* Apply suggestions from code review

Co-authored-by: Sergey Lyalin

* Apply suggestions from code review

Co-authored-by: Helena Kloosterman

* update requirements and warning messages

* Apply suggestions from code review

Co-authored-by: Ella Charlaix

* fix cli export

* Update optimum/exporters/openvino/__main__.py

Co-authored-by: Ella Charlaix

---------

Co-authored-by: Sergey Lyalin
Co-authored-by: Helena Kloosterman
Co-authored-by: Ella Charlaix
4 people authored Jan 16, 2024
1 parent 133aa7d commit 7f236c2
Showing 14 changed files with 573 additions and 91 deletions.
14 changes: 13 additions & 1 deletion optimum/commands/export/openvino.py
@@ -92,6 +92,17 @@ def parse_args_openvino(parser: "ArgumentParser"):
"precision (by default 20%% in INT8). This helps to achieve better accuracy after weight compression."
),
)
optional_group.add_argument(
"--disable-stateful",
action="store_true",
help=(
"Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. "
"In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. "
"If --disable-stateful option is used, it may result in sub-optimal inference performance. "
"Use it when you intentionally want to use a stateless model, for example, to be compatible with existing "
"OpenVINO native inference code that expects kv-cache inputs and outputs in the model."
),
)
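The help text above contrasts two model interfaces. As a toy illustration only (plain Python, not OpenVINO; `StatelessDecoder`, `StatefulDecoder` and the list-based cache are hypothetical stand-ins invented for this sketch), the difference is who owns the kv-cache between decoding steps:

```python
from typing import List, Tuple


class StatelessDecoder:
    """Caller owns the cache: it is passed in and returned on every step."""

    def step(self, token: int, past: List[int]) -> Tuple[int, List[int]]:
        new_past = past + [token]
        # Toy "output": the sum of every token seen so far.
        return sum(new_past), new_past


class StatefulDecoder:
    """Model owns the cache: it never appears in the inputs or outputs."""

    def __init__(self) -> None:
        self._past: List[int] = []

    def step(self, token: int) -> int:
        self._past.append(token)
        return sum(self._past)

    def reset_state(self) -> None:
        # Has to be called before a new, unrelated sequence is decoded.
        self._past.clear()


stateless, past = StatelessDecoder(), []
stateful = StatefulDecoder()
for tok in (1, 2, 3):
    out_a, past = stateless.step(tok, past)
    out_b = stateful.step(tok)
    # Both variants compute the same result; only the interface differs.
    assert out_a == out_b
```

In the stateless variant the caller's code has to thread `past` through every call, which is the situation the `--disable-stateful` help text describes for existing native inference code.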


class OVExportCommand(BaseOptimumCLICommand):
@@ -138,6 +149,7 @@ def run(self):
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
compression_option=self.args.weight_format,
compression_ratio=self.args.ratio
compression_ratio=self.args.ratio,
stateful=not self.args.disable_stateful,
# **input_shapes,
)
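The flag is an opt-out, so the value passed downstream is its negation. A minimal standalone sketch of that mapping, using only `argparse` (the `build_parser` and `resolve_stateful` helpers and the `export-sketch` program name are hypothetical, not part of optimum-cli):

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="export-sketch")
    parser.add_argument(
        "--disable-stateful",
        action="store_true",
        help="Export a stateless model instead of the default stateful one.",
    )
    return parser


def resolve_stateful(argv: List[str]) -> bool:
    args = build_parser().parse_args(argv)
    # argparse turns "--disable-stateful" into the attribute "disable_stateful".
    return not args.disable_stateful


print(resolve_stateful([]))                      # True
print(resolve_stateful(["--disable-stateful"]))  # False
```

With `action="store_true"` the attribute defaults to `False`, so omitting the flag yields `stateful=True`, matching the "stateful by default" behaviour this commit introduces.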
1 change: 1 addition & 0 deletions optimum/exporters/openvino/__init__.py
@@ -1,5 +1,6 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx
from .stateful import ensure_stateful_is_available, patch_stateful


__all__ = ["main_export", "export", "export_models"]
8 changes: 8 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -28,6 +28,7 @@

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
from .convert import export_models
from .stateful import ensure_export_task_support_stateful


if is_optimum_version(">=", "1.16.0"):
@@ -65,6 +66,7 @@ def main_export(
fn_get_submodels: Optional[Callable] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = True,
**kwargs_shapes,
):
"""
@@ -124,6 +126,8 @@ def main_export(
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`bool`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
@@ -277,6 +281,9 @@ class StoreAttr(object):
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")

task_support_stateful = ensure_export_task_support_stateful(task)
stateful = stateful and task_support_stateful

preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
@@ -373,6 +380,7 @@ class StoreAttr(object):
device=device,
compression_option=compression_option,
compression_ratio=compression_ratio,
stateful=stateful,
model_kwargs=model_kwargs,
)
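The hunks above downgrade the requested `stateful` flag when the export task cannot use it. The body of `ensure_export_task_support_stateful` is not shown in this part of the diff, so the following is only a sketch of the gating logic under an assumption: that a single task, `text-generation-with-past`, is the one with a kv-cache to hide. `STATEFUL_TASKS`, `task_supports_stateful` and `effective_stateful` are hypothetical names.

```python
# Assumption for illustration: only this task carries a kv-cache to hide.
STATEFUL_TASKS = {"text-generation-with-past"}


def task_supports_stateful(task: str) -> bool:
    return task in STATEFUL_TASKS


def effective_stateful(requested: bool, task: str) -> bool:
    # The user's request is honoured only when the task can use it;
    # otherwise the export silently stays stateless.
    return requested and task_supports_stateful(task)


print(effective_stateful(True, "text-generation-with-past"))   # True
print(effective_stateful(True, "text-classification"))         # False
print(effective_stateful(False, "text-generation-with-past"))  # False
```

Combining the two conditions with `and` means that `stateful=True` as a default never raises for encoder-only or non-generative tasks.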

47 changes: 46 additions & 1 deletion optimum/exporters/openvino/convert.py
@@ -32,6 +32,8 @@
from optimum.utils import is_diffusers_available

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
from .model_patcher import patch_model_with_bettertransformer
from .stateful import ensure_stateful_is_available, patch_stateful
from .utils import (
OV_XML_FILE_NAME,
clear_class_registry,
@@ -102,6 +104,7 @@ def export(
model_kwargs: Optional[Dict[str, Any]] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = True,
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -125,6 +128,8 @@ def export(
Compression ratio between primary and backup precision (only relevant to INT4).
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the exporter.
stateful (`bool`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -139,6 +144,10 @@ def export(
if "diffusers" in str(model.__class__) and not is_diffusers_available():
raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.")

if stateful:
# This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used
stateful = ensure_stateful_is_available()

if is_torch_available() and isinstance(model, nn.Module):
return export_pytorch(
model,
@@ -150,6 +159,7 @@ def export(
compression_option=compression_option,
compression_ratio=compression_ratio,
model_kwargs=model_kwargs,
stateful=stateful,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -160,7 +170,9 @@ def export(
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if input_shapes is not None:
logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
return export_tensorflow(model, config, opset, output)
return export_tensorflow(
model, config, opset, output, compression_option=compression_option, compression_ratio=compression_ratio
)

else:
raise RuntimeError(
@@ -271,6 +283,7 @@ def export_pytorch(
model_kwargs: Optional[Dict[str, Any]] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -291,6 +304,13 @@ def export_pytorch(
If specified, allows to use specific shapes for the example input provided to the exporter.
model_kwargs (optional[Dict[str, Any]], defaults to `None`):
Additional kwargs for model export
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`bool`, defaults to `False`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
Returns:
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -302,6 +322,15 @@ def export_pytorch(
logger.info(f"Using framework PyTorch: {torch.__version__}")
output = Path(output)

if stateful:
# Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
# both of them are applied to demonstrate the best performance.
# TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.
model = patch_model_with_bettertransformer(model)
# TODO: Consider unpatching model after export is done in the end of this function.
# Now it is left as-is because the model is not expected to be used after call export_pytorch, and
# this function is one of the _internal_ steps in a bigger model conversion pipeline.

with torch.no_grad():
model.config.torchscript = False
model.config.return_dict = True
@@ -380,6 +409,14 @@ def ts_patched_forward(*args, **kwargs):
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
if patch_model_forward:
model.forward = orig_forward
if stateful:
# cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
# TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
logger.warning(
"[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
"A stateless model will be exported instead. It may result in sub-optimal inference performance. "
"Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
)
return export_pytorch_via_onnx(
model,
config,
@@ -411,6 +448,10 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()

if stateful:
patch_stateful(model.config, ov_model)

_save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio)
clear_class_registry()
del model
@@ -430,6 +471,7 @@ def export_models(
model_kwargs: Optional[Dict[str, Any]] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[int] = None,
stateful: bool = True,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
@@ -451,6 +493,8 @@ def export_models(
Compression ratio between primary and backup precision (only relevant to INT4).
model_kwargs (Optional[Dict[str, Any]], optional):
Additional kwargs for model export.
stateful (`bool`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
Raises:
ValueError: if custom names set not equal of number of models
@@ -481,6 +525,7 @@ def export_models(
model_kwargs=model_kwargs,
compression_option=compression_option,
compression_ratio=compression_ratio,
stateful=stateful,
)
)
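Taken together, the `convert.py` hunks give `export_pytorch` this control flow: try the direct conversion, patch the result to stateful on success, and fall back to a stateless ONNX-based export (with a warning rather than an error) on failure. A self-contained sketch of that flow, with the three conversion steps replaced by caller-supplied stubs (`export_with_fallback` and its parameters are hypothetical names, not the library's API):

```python
import logging
from typing import Any, Callable

logger = logging.getLogger("export_sketch")


def export_with_fallback(
    convert_directly: Callable[[], Any],
    convert_via_onnx: Callable[[], Any],
    make_stateful: Callable[[Any], None],
    stateful: bool = True,
) -> Any:
    try:
        ov_model = convert_directly()
    except Exception as ex:
        logger.warning("Direct export failed with: %s. Falling back to ONNX.", ex)
        if stateful:
            # Warn instead of raising: stateful is the default, and raising
            # would break models that only convert through the ONNX path.
            logger.warning("A stateless model will be exported instead.")
        return convert_via_onnx()
    if stateful:
        make_stateful(ov_model)
    return ov_model


def _failing_conversion() -> Any:
    raise RuntimeError("unsupported op")


patched = []
direct = export_with_fallback(lambda: "direct", lambda: "onnx", patched.append)
fallback = export_with_fallback(_failing_conversion, lambda: "onnx", patched.append)
```

Only the directly converted model reaches `make_stateful`; the fallback result is returned untouched, which mirrors the warning text in the diff.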

39 changes: 39 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -0,0 +1,39 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging as log

from optimum.intel.utils.import_utils import is_torch_version


def patch_model_with_bettertransformer(model):
if is_torch_version("<", "2.0"):
log.warning(
"Integration of the Scaled Dot Product Attention optimization is supported only with torch >= 2.0. "
"Using a model with stateful=True may be inefficient if the model does not contain torch.nn.functional.scaled_dot_product_attention. "
"It is recommended to upgrade PyTorch to use a stateful model, or to use stateful=False."
)
# model already has required SDPA implementation
if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
return model
try:
model = model.to_bettertransformer()
except Exception as e:
log.warning(
f"Cannot apply model.to_bettertransformer because of the exception:\n{e}."
" Using a model with stateful=True may be inefficient if the model does not contain torch.nn.functional.scaled_dot_product_attention"
)
return model

return model
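`patch_model_with_bettertransformer` is a best-effort patch: skip it when the model already uses SDPA, try the conversion otherwise, and hand back the original model if the conversion raises. The pattern can be exercised without torch or optimum by using dummy classes; everything below except the attribute names `_supports_sdpa`, `_attn_implementation` and `to_bettertransformer` (taken from the diff) is a hypothetical stand-in.

```python
import logging

logger = logging.getLogger("patch_sketch")


def patch_best_effort(model):
    config = getattr(model, "config", None)
    # Nothing to do when the model already reports an SDPA attention implementation.
    if getattr(model, "_supports_sdpa", False) and getattr(config, "_attn_implementation", "eager") == "sdpa":
        return model
    try:
        return model.to_bettertransformer()
    except Exception as ex:
        # Unsupported architectures must not abort the export.
        logger.warning("Could not apply to_bettertransformer: %s", ex)
        return model


class _SdpaConfig:
    _attn_implementation = "sdpa"


class SdpaModel:
    _supports_sdpa = True
    config = _SdpaConfig()

    def to_bettertransformer(self):
        raise AssertionError("should not be called for SDPA models")


class UnsupportedModel:
    config = object()

    def to_bettertransformer(self):
        raise NotImplementedError("architecture not supported")


class ConvertibleModel:
    config = object()

    def to_bettertransformer(self):
        return "converted"


sdpa_model, unsupported_model = SdpaModel(), UnsupportedModel()
assert patch_best_effort(sdpa_model) is sdpa_model
assert patch_best_effort(unsupported_model) is unsupported_model
assert patch_best_effort(ConvertibleModel()) == "converted"
```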