diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 95ecea1213..be97e665b9 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -92,6 +92,14 @@ def parse_args_openvino(parser: "ArgumentParser"): "precision (by default 20% in INT8). This helps to achieve better accuracy after weight quantization." ), ) + optional_group.add_argument( + "--no-stateful", + action="store_true", + help=( + "Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. " + "In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs.", + ), + ) class OVExportCommand(BaseOptimumCLICommand): @@ -138,6 +146,7 @@ def run(self): trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, compression_option=self.args.weight_format, - compression_ratio=self.args.ratio + compression_ratio=self.args.ratio, + stateful=not self.args.no_stateful, # **input_shapes, ) diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py index d87d8dda9e..f94e4ba5c5 100644 --- a/optimum/exporters/openvino/__init__.py +++ b/optimum/exporters/openvino/__init__.py @@ -1,5 +1,6 @@ from .__main__ import main_export from .convert import export, export_models, export_pytorch_via_onnx +from .stateful import patch_stateful, raise_if_openvino_is_too_old __all__ = ["main_export", "export", "export_models"] diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 54fe1193e5..73fef5c0bf 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -65,6 +65,7 @@ def main_export( fn_get_submodels: Optional[Callable] = None, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, + stateful: Optional[bool] = True, **kwargs_shapes, ): """ @@ -124,6 +125,8 @@ def main_export( `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression. compression_ratio (`Optional[float]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). + stateful (`Optional[bool]`, defaults to `True`): + Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -373,6 +376,7 @@ class StoreAttr(object): device=device, compression_option=compression_option, compression_ratio=compression_ratio, + stateful=stateful, model_kwargs=model_kwargs, ) diff --git a/optimum/exporters/openvino/better_transformer_patch.py b/optimum/exporters/openvino/better_transformer_patch.py new file mode 100644 index 0000000000..8cc98185f8 --- /dev/null +++ b/optimum/exporters/openvino/better_transformer_patch.py @@ -0,0 +1,39 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging as log + +from optimum.intel.utils.import_utils import is_torch_version + + +def patch_model_with_bettertransformer(model): + if is_torch_version("<", "2.0"): + log.warn( + "integration Scaled Dot Product Attention optimization supported only with torch > 2.0." + "Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention" + "It is recommended to upgrade PyTorch version for using stateful model or use stateful=Flase" + ) + # model already has required SDPA implementation + if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": + return model + try: + model = model.to_bettertransformer() + except Exception as e: + log.warn( + f"Cannot apply model.to_bettertransformer because of the exception:\n{e}." + " Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention" + ) + return model + + return model diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 56c5a10e5d..1ee72dba1f 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -32,6 +32,8 @@ from optimum.utils import is_diffusers_available from ...intel.utils.import_utils import is_nncf_available, is_optimum_version +from .better_transformer_patch import patch_model_with_bettertransformer +from .stateful import patch_stateful, raise_if_openvino_is_too_old from .utils import ( OV_XML_FILE_NAME, clear_class_registry, @@ -102,6 +104,7 @@ def export( model_kwargs: Optional[Dict[str, Any]] = None, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, + stateful: bool = True, ) -> Tuple[List[str], List[str]]: """ Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation. @@ -125,6 +128,8 @@ def export( Compression ratio between primary and backup precision (only relevant to INT4). input_shapes (`Optional[Dict]`, defaults to `None`): If specified, allows to use specific shapes for the example input provided to the exporter. + stateful (`Optional[bool]`, defaults to `True`): + Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from @@ -150,6 +155,7 @@ def export( compression_option=compression_option, compression_ratio=compression_ratio, model_kwargs=model_kwargs, + stateful=stateful, ) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): @@ -271,6 +277,7 @@ def export_pytorch( model_kwargs: Optional[Dict[str, Any]] = None, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, + stateful: bool = True, ) -> Tuple[List[str], List[str]]: """ Exports a PyTorch model to an OpenVINO Intermediate Representation. @@ -291,6 +298,13 @@ def export_pytorch( If specified, allows to use specific shapes for the example input provided to the exporter. 
model_kwargs (optional[Dict[str, Any]], defaults to `None`): Additional kwargs for model export + compression_option (`Optional[str]`, defaults to `None`): + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point, + `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point. + compression_ratio (`Optional[float]`, defaults to `None`): + Compression ratio between primary and backup precision (only relevant to INT4). + stateful (`Optional[bool]`, defaults to `True`): + Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs Returns: `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from @@ -302,6 +316,15 @@ def export_pytorch( logger.info(f"Using framework PyTorch: {torch.__version__}") output = Path(output) + if stateful: + # Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect + # both of them are applied to demonstrate the best performance. + # TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation. + model = patch_model_with_bettertransformer(model) + # TODO: Consider unpatching model after export is done in the end of this function. + # Now it is left as-is because the model is not expected to be used after call export_pytorch, and + # this function is one of the _internal_ steps in a bigger model conversion pipeline. + with torch.no_grad(): model.config.torchscript = False model.config.return_dict = True @@ -380,6 +403,13 @@ def ts_patched_forward(*args, **kwargs): logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX") if patch_model_forward: model.forward = orig_forward + if stateful: + # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly + # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation + print( + "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. Stateless model will be exported instead. " + "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path." + ) return export_pytorch_via_onnx( model, config, @@ -411,6 +441,10 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() + + if stateful: + patch_stateful(model.config, ov_model) + _save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio) clear_class_registry() del model @@ -430,6 +464,7 @@ def export_models( model_kwargs: Optional[Dict[str, Any]] = None, compression_option: Optional[str] = None, compression_ratio: Optional[int] = None, + stateful: bool = True, ) -> Tuple[List[List[str]], List[List[str]]]: """ Export the models to OpenVINO IR format @@ -451,6 +486,8 @@ def export_models( Compression ratio between primary and backup precision (only relevant to INT4). model_kwargs (Optional[Dict[str, Any]], optional): Additional kwargs for model export. 
+ stateful (`Optional[bool]`, defaults to `True`) + Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs Raises: ValueError: if custom names set not equal of number of models @@ -458,6 +495,9 @@ def export_models( Returns: list of input_names and output_names from ONNX configuration """ + if stateful: + # This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used + raise_if_openvino_is_too_old() outputs = [] if output_names is not None and len(output_names) != len(models_and_onnx_configs): @@ -481,6 +521,7 @@ def export_models( model_kwargs=model_kwargs, compression_option=compression_option, compression_ratio=compression_ratio, + stateful=stateful, ) ) diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py new file mode 100644 index 0000000000..baa49ddd62 --- /dev/null +++ b/optimum/exporters/openvino/stateful.py @@ -0,0 +1,230 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging as log +from typing import List + +import numpy as np +from transformers import PretrainedConfig + +import openvino as ov +from openvino.runtime import opset13 +from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version +from optimum.utils.normalized_config import NormalizedConfigManager + + +def model_has_state(ov_model: ov.Model): + # TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods + return len(ov_model.get_sinks()) > 0 + + +def model_has_input_output_name(ov_model: ov.Model, name: str): + """ + Helper function for checking that model has specified input or output name + + Parameters: + ov_model (ov.Model): # TODO: Can we derive the dimensions from the model topology? + name (str): + name of input or output + + Returns: + True if input or output with requested name exists else False + """ + return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) + + +def model_has_input(ov_model: ov.Model, name: str): + """ + Helper function for checking that model has specified input name + + Parameters: + ov_model (ov.Model): + opennvino model + name (str): + name of input + + Returns: + True if input with requested name exists else False + """ + return name in sum([list(t.get_names()) for t in ov_model.inputs], []) + + +def fuse_cache_reorder( + ov_model: ov.Model, not_kv_inputs: List[str], key_value_input_names: List[str], gather_dim: int +): + """ + Fuses reored_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly. + + Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model. + Should be run before make_stateful. Implements optimumum's _reorder_cache + inside the model in the beginning of each iteration. 
+ Gather works along given gather_dim dimension that may vary from model to model. + KV-cache inputs are identified based on names in key_value_input_names. + Append the new beam_idx parameter to not_kv_inputs. + + Parameters: + ov_model (`ov.Model`): + openvino model for processing + not_kv_inputs (`List[str]`): + list of input nodes in model that not related to past key values + key_value_input_names (`List[str]`): + list of names for key value input layers + gather_dim (int): + dimension for gathering cache during reorder pass + """ + + assert not model_has_input_output_name(ov_model, "beam_idx") + input_batch = ov_model.input("input_ids").get_partial_shape()[0] + beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])) + beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted? + ov_model.add_parameters([beam_idx]) + not_kv_inputs.append(ov_model.inputs[-1]) + # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx + for input_name in key_value_input_names: + parameter_output_port = ov_model.input(input_name) + consumers = parameter_output_port.get_target_inputs() + gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim)) + for consumer in consumers: + consumer.replace_source_output(gather.output(0)) + ov_model.validate_nodes_and_infer_types() + + +def build_state_initializer(ov_model: ov.Model, batch_dim: int): + """ + Build initialization ShapeOf Expression for all ReadValue ops + + Parameters: + ov_model (ov.Model): + openvino model + batch_dim (int): + index of dimension corresponding to batch size + """ + input_ids = ov_model.input("input_ids") + batch = opset13.gather(opset13.shape_of(input_ids, output_type="i64"), opset13.constant([0]), opset13.constant(0)) + for op in ov_model.get_ops(): + if op.get_type_name() == "ReadValue": + dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))] + dims[batch_dim] = batch + dims = [opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim for dim in dims] + shape = opset13.concat(dims, axis=0) + broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape) + op.set_arguments([broadcast]) + ov_model.validate_nodes_and_infer_types() + + +def make_stateful( + ov_model: ov.Model, + not_kv_inputs: List[str], + key_value_input_names: List[str], + key_value_output_names: List[str], + batch_dim: int, + num_attention_heads: int, + num_beams_and_batch: int = None, +): + """ + Hides kv-cache inputs and outputs inside the model as variables. + + Parameters: + ov_model (ov.Model): + openvino model + not_kv_inputs (`List[str]`): + list of input nodes in model that not related to past key values + key_value_input_names (`List[str]`): + list of names for key value input layers + key_value_output_names (`List[str]`): + list of names for key value input layers + batch_dim (int): + index of batch dimension in key value layers + num_attention_heads (int): + number of attention heads for batch dimension initialization + num_beams_an_batch (int): + precalculated number of beams and batch for shapes initialization + """ + from openvino._offline_transformations import apply_make_stateful_transformation + + input_output_map = {} + # TODO: Can we derive the dimensions from the model topology? 
+ + if num_beams_and_batch is not None: + # Set batch size for input_ids and attention mask to avoid dynamic dimension got propagated from the end of the model back to ReadValue + for input in not_kv_inputs: + shape = input.get_partial_shape() + if shape.rank.get_length() <= 2: # == 1 for beam_index + shape[0] = num_beams_and_batch + input.get_node().set_partial_shape(shape) + else: + log.warn(f"Rank of {input.get_any_name()} input of the model is not 2, batch size is not set") + + for kv_name_pair in zip(key_value_input_names, key_value_output_names): + input_output_map[kv_name_pair[0]] = kv_name_pair[1] + if num_beams_and_batch is not None: + input = ov_model.input(kv_name_pair[0]) + shape = input.get_partial_shape() + shape[batch_dim] = num_beams_and_batch * num_attention_heads + input.get_node().set_partial_shape(shape) + + if num_beams_and_batch is not None: + # Re-validation model if shapes are altered above + ov_model.validate_nodes_and_infer_types() + + apply_make_stateful_transformation(ov_model, input_output_map) + if num_beams_and_batch is None: + build_state_initializer(ov_model, batch_dim) + + +def raise_if_openvino_is_too_old(): + """ + Check openvino version and raise error if it does not support stateful models + """ + if is_openvino_version("<", "2023.3"): + raise ValueError( + f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0." + ) + + +def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): + """ + Apply make stateful transofrmation to model fo hiding key values inputs inside model. + Select transformation parameters based on model architecture + + Parameters: + config (`PretrainedConfig`): + model pretrained config + ov_model (`ov.Model`): + openvino model + """ + raise_if_openvino_is_too_old() + + key_value_input_names = [ + key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names()) + ] + key_value_output_names = [ + key.get_any_name() for key in ov_model.outputs if any("present" in key_name for key_name in key.get_names()) + ] + not_kv_inputs = [ + input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names()) + ] + + # By default, batch is the 0-th but chatglm uses 1-st dimension as batch + # TODO: Deduce from a model via ordinal reshape (?) 
and topology + batch_dim = 1 if config.model_type == "chatglm" else 0 + + fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim) + + normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1 + + make_stateful( + ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None + ) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 14f8dbcafa..7fb367f063 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -29,7 +29,8 @@ from optimum.utils import NormalizedConfigManager -from ...exporters.openvino import main_export +from ...exporters.openvino import main_export, patch_stateful, raise_if_openvino_is_too_old +from ...exporters.openvino.stateful import model_has_state from ..utils.import_utils import is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel @@ -125,7 +126,10 @@ def __init__( self.is_dynamic = dynamic_shapes use_cache = kwargs.pop("use_cache", True) - self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) + stateful = kwargs.pop("stateful", True) + model_has_sinks = model_has_state(self.model) + self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks + self.stateful = model_has_sinks self.main_input_name = "input_ids" self.num_pkv = 2 self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) @@ -136,19 +140,34 @@ def __init__( self.update_pkv_precision() if self.is_dynamic: self.model = self._reshape(self.model, -1, -1) - if enable_compilation: - self.compile() - if use_cache ^ self.use_cache: + if self.stateful or stateful: + raise_if_openvino_is_too_old() + + def raise_error(model_prop, user_prop, name): raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " - f"Please load your current model with `use_cache={self.use_cache}` or export the original model " - f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " + f"`{name}` was set to `{user_prop}` but the loaded model only supports `{name}={model_prop}`. " + f"Please load your current model with `{name}={model_prop}` or export the original model " + f"once again with `{name}={user_prop}` when calling the `from_pretrained` method. " "To export your model, simply set `export=True`." 
) + if stateful is not None and self.stateful and not stateful: + # We cannot transform stateful model to stateless + raise_error(self.stateful, stateful, "stateful") + + if not self.stateful and stateful: + # We can transform stateless model to stateful + self._make_stateful() + + if enable_compilation: + self.compile() + + if use_cache ^ self.use_cache: + raise_error(self.use_cache, use_cache, "use_cache") + def update_pkv_precision(self, force_fp32=False): - if not self.use_cache: + if not self.use_cache or self.stateful: return pkv_precision = Type.f32 @@ -276,6 +295,8 @@ def _reshape( shapes[inputs][1] = -1 else: shapes[inputs][2] = -1 + elif input_name.startswith("beam_idx"): + shapes[inputs][0] = -1 else: shapes[inputs][1] = -1 model.reshape(shapes) @@ -290,6 +311,10 @@ def compile(self): super().compile() self.request = self.request.create_infer_request() + def _make_stateful(self): + patch_stateful(self.config, self.model) + self.stateful = True + @add_start_docstrings( """ @@ -319,49 +344,69 @@ def forward( **kwargs, ) -> CausalLMOutputWithPast: self.compile() - inputs = {} if self.use_cache and past_key_values is not None: input_ids = input_ids[:, -1:] + batch_size = input_ids.shape[0] + if self.config.model_type == "bloom": + batch_size *= self.normalized_config.num_attention_heads + inputs = {} past_len = 0 - if past_key_values is not None: - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: - past_len = past_key_values[0][1].shape[-2] - if self._pkv_precision == Type.bf16: - # numpy does not support bf16, pretending f16, should change to bf16 - past_key_values = tuple( - Tensor(past_key_value, past_key_value.shape, Type.bf16) - for pkv_per_layer in past_key_values - for past_key_value in pkv_per_layer - ) + if not self.stateful: + if past_key_values is not None: + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + past_len = past_key_values[0][1].shape[-2] + if self._pkv_precision == Type.bf16: + # numpy does not support bf16, pretending f16, should change to bf16 + past_key_values = tuple( + Tensor(past_key_value, past_key_value.shape, Type.bf16) + for pkv_per_layer in past_key_values + for past_key_value in pkv_per_layer + ) + else: + # Flatten the past_key_values + past_key_values = tuple( + past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer + ) else: - # Flatten the past_key_values - past_key_values = tuple( - past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer - ) - else: - past_len = past_key_values[0].shape[-2] - - # Add the past_key_values to the decoder inputs - inputs = dict(zip(self.key_value_input_names, past_key_values)) - - # Create empty past_key_values for decoder_with_past first generation step - elif self.use_cache: - batch_size = input_ids.shape[0] - if self.config.model_type == "bloom": - batch_size *= self.normalized_config.num_attention_heads - - for input_name in self.key_value_input_names: - model_inputs = self.model.input(input_name) - shape = model_inputs.get_partial_shape() - shape[0] = batch_size - if shape[2].is_dynamic: - shape[2] = 0 + past_len = past_key_values[0].shape[-2] + + # Add the past_key_values to the decoder inputs + inputs = dict(zip(self.key_value_input_names, past_key_values)) + + # Create empty past_key_values for decoder_with_past first generation step + elif self.use_cache: + for input_name in self.key_value_input_names: + model_inputs = self.model.input(input_name) + shape = model_inputs.get_partial_shape() + if self.config.model_type == 
"chatglm": + shape[0] = 0 + shape[1] = batch_size + else: + shape[0] = batch_size + if shape[2].is_dynamic: + shape[2] = 0 + else: + shape[1] = 0 + inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape()) + else: + # past_key_values are not used explicitly, instead they are handled inside the model + if past_key_values is None: + # Need a marker to differentiate the first generate iteration from the others in + # the first condition at the function beginning above. + # It should be something that is not None and it should be True when converted to Boolean. + past_key_values = ((),) + # This is the first iteration in a sequence, reset all states + if hasattr(self.request, "reset_state"): + self.request.reset_state() else: - shape[1] = 0 - inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape()) + for state in self.request.query_state(): + state.reset() + # Set initial value for the next beam_idx input that will be used at the current iteration + # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used + self.next_beam_idx = np.array(range(batch_size), dtype=int) inputs["input_ids"] = np.array(input_ids) # Add the attention_mask inputs when needed @@ -387,21 +432,25 @@ def forward( inputs["position_ids"] = position_ids + if hasattr(self, "next_beam_idx"): + inputs["beam_idx"] = self.next_beam_idx + # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) - if self.use_cache: - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) - past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: - # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) - past_key_values = tuple( - past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) - ) - else: - past_key_values = None + if not self.stateful: + if self.use_cache: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) + past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) + past_key_values = tuple( + past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) + ) + else: + past_key_values = None return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) @@ -428,18 +477,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg } # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - @staticmethod def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. 
""" - return tuple( - tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values - ) + if self.stateful: + # TODO: Apply it differently based on model type + # TODO: At least for bloom we need to replicate values for each attention head + self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + return past_key_values + else: + return tuple( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values + ) def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" @@ -500,7 +554,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg use_cache = kwargs.get("use_cache", None) # only last token for input_ids if past is not None - if past_key_values: + if past_key_values and not self.stateful: # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) @@ -522,15 +576,23 @@ def _reorder_cache( [`~PreTrainedModel.beam_sample`] is called for bloom architecture. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ - standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) - reordered_past = tuple( - ( - np.take(layer_past[0], beam_idx, 0), - np.take(layer_past[1], beam_idx, 0), + if self.stateful: + beam_idx = np.array(beam_idx) + batch_size = beam_idx.shape[0] + indices = np.array(range(batch_size * self.normalized_config.num_attention_heads)) + indices = indices.reshape([batch_size, self.normalized_config.num_attention_heads]) + self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() + return past_key_values + else: + standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) + reordered_past = tuple( + ( + np.take(layer_past[0], beam_idx, 0), + np.take(layer_past[1], beam_idx, 0), + ) + for layer_past in standardized_past ) - for layer_past in standardized_past - ) - return self._convert_to_bloom_cache(reordered_past) + return self._convert_to_bloom_cache(reordered_past) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache @staticmethod @@ -602,8 +664,11 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg class OVGPTBigCodeForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache - @staticmethod def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: - return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) + if self.stateful: + self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + return past_key_values + else: + return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index afa5ff81dd..3f5d270da9 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -313,9 +313,11 @@ def start_async( inputs: Any = None, userdata: Any = None, share_inputs: 
bool = False, + *, + shared_memory: Any = None, ): data_cache.append(inputs) - self.request.infer(inputs, share_inputs) + self.request.infer(inputs, share_inputs, share_outputs=True, shared_memory=shared_memory) def wait(self): pass diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index f778bbfcbd..3f3fa6c55b 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -70,12 +70,17 @@ _openvino_version = "N/A" if _openvino_available: try: - _openvino_version = importlib_metadata.version("openvino") - except importlib_metadata.PackageNotFoundError: - try: - _openvino_version = importlib_metadata.version("openvino-nightly") - except importlib_metadata.PackageNotFoundError: - _openvino_available = False + from openvino.runtime import get_version + + version = get_version() + # avoid invalid format + if "-" in version: + major_version, dev_info = version.split("-", 1) + commit_id = dev_info.split("-")[0] + version = f"{major_version}-{commit_id}" + _openvino_version = version + except ImportError: + _openvino_available = False _nncf_available = importlib.util.find_spec("nncf") is not None diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index dc33b39f2a..ad32560282 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -73,6 +73,7 @@ from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder from optimum.intel.openvino.modeling_timm import TimmImageProcessor +from optimum.intel.utils.import_utils import is_openvino_version from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, @@ -487,6 +488,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "pegasus", ) GENERATION_LENGTH = 100 + IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -600,7 +602,7 @@ def test_default_filling_attention_mask(self): attention_mask = tokens.pop("attention_mask") outs_without_attn_mask = model_with_cache(**tokens) self.assertTrue(torch.allclose(outs.logits, outs_without_attn_mask.logits)) - input_ids = torch.argmax(outs.logits, dim=2) + input_ids = torch.argmax(outs.logits[:, -1:, :], dim=2) past_key_values = outs.past_key_values attention_mask = torch.ones((input_ids.shape[0], tokens.input_ids.shape[1] + 1), dtype=torch.long) outs_step2 = model_with_cache( @@ -611,6 +613,106 @@ def test_default_filling_attention_mask(self): del model_with_cache gc.collect() + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + def test_stateful(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True) + self.assertIsInstance(ov_model.config, PretrainedConfig) + self.assertTrue(ov_model.stateful) + self.assertTrue(ov_model.use_cache) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None + ) + position_ids = None + input_shape = tokens["input_ids"].shape + if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: + 
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) + ov_outputs = ov_model(**tokens, position_ids=position_ids) + + self.assertTrue("logits" in ov_outputs) + self.assertIsInstance(ov_outputs.logits, torch.Tensor) + self.assertTrue("past_key_values" in ov_outputs) + self.assertIsInstance(ov_outputs.past_key_values, tuple) + self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) + with torch.no_grad(): + transformers_outputs = transformers_model(**tokens) + # Compare tensor outputs + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + next_token = torch.argmax(ov_outputs.logits[..., -1:, :], dim=2) + attention_mask = torch.ones((input_shape[0], input_shape[1] + 1), dtype=torch.long) + if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: + position_ids = position_ids[:, -1:] + 1 + pkv = ov_outputs.past_key_values + ov_outputs = ov_model( + input_ids=next_token, position_ids=position_ids, attention_mask=attention_mask, past_key_values=pkv + ) + self.assertTrue("logits" in ov_outputs) + self.assertIsInstance(ov_outputs.logits, torch.Tensor) + self.assertTrue("past_key_values" in ov_outputs) + self.assertIsInstance(ov_outputs.past_key_values, tuple) + self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) + with torch.no_grad(): + transformers_outputs = transformers_model( + input_ids=next_token, + attention_mask=attention_mask, + past_key_values=transformers_outputs.past_key_values, + ) + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + + del transformers_model + del ov_model + gc.collect() + + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + def test_stateful_on_converted_model(self): + model_id = "vuiseng9/ov-gpt2-fp32-kv-cache" + # reference without state + loaded_model = OVModelForCausalLM.from_pretrained(model_id, stateful=False) + self.assertIsInstance(loaded_model.config, PretrainedConfig) + self.assertFalse(loaded_model.stateful) + self.assertTrue(loaded_model.use_cache) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + loaded_model_outputs = loaded_model(**tokens) + + # explicit stateful model specified during loading + loaded_stateful_model = OVModelForCausalLM.from_pretrained(model_id, stateful=True) + self.assertIsInstance(loaded_model.config, PretrainedConfig) + self.assertTrue(loaded_model.stateful) + self.assertTrue(loaded_model.use_cache) + loaded_stateful_model_outputs = loaded_stateful_model(**tokens) + self.assertTrue(torch.equal(loaded_model_outputs.logits, loaded_stateful_model_outputs.logits)) + self.assertTrue("past_key_values" in loaded_stateful_model_outputs) + self.assertIsInstance(loaded_stateful_model_outputs.past_key_values, tuple) + self.assertTrue( + len(loaded_stateful_model_outputs.past_key_values) == 1 + and len(loaded_stateful_model_outputs.past_key_values[0]) == 0 + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + loaded_stateful_model.save_pretrained(tmpdirname) + folder_contents = os.listdir(tmpdirname) + self.assertTrue(OV_XML_FILE_NAME in folder_contents) + self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents) + # implicit load stateful model from disk + model = OVModelForCausalLM.from_pretrained(tmpdirname) + self.assertTrue(model.stateful) + 
self.assertTrue(model.use_cache) + + outputs = model(**tokens) + self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits)) + self.assertTrue("past_key_values" in outputs) + self.assertIsInstance(outputs.past_key_values, tuple) + self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0) + del loaded_model + del loaded_stateful_model + del model + gc.collect() + class OVModelForMaskedLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index a08da51aab..955a55723b 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -51,6 +51,7 @@ from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG +from optimum.intel.utils.import_utils import is_openvino_version from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8 _TASK_TO_DATASET = { @@ -166,6 +167,8 @@ class OVWeightCompressionTest(unittest.TestCase): (OVStableDiffusionXLPipeline, "stable-diffusion-xl"), ) + IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS) def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): task = model_cls.export_feature @@ -239,6 +242,40 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i outputs = model(**tokens) self.assertTrue("logits" in outputs) + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS) + def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, expected_int8, expected_int4): + task = model_cls.export_feature + + with tempfile.TemporaryDirectory() as tmp_dir: + model_id = MODEL_NAMES[model_name] + transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + quantizer.quantize( + save_directory=tmp_dir, + weights_only=True, + quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}), + ) + model = model_cls.from_pretrained(tmp_dir) + self.assertTrue(model.stateful) + self.assertTrue(model.use_cache) + + _, num_int8, num_int4 = get_num_quantized_nodes(model) + self.assertEqual(expected_int8, num_int8) + self.assertEqual(expected_int4, num_int4) + + tokens = tokenizer("This is a sample input", return_tensors="pt") + outputs = model(**tokens) + + self.assertTrue("logits" in outputs) + self.assertTrue("past_key_values" in outputs) + self.assertIsInstance(outputs.past_key_values, tuple) + self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True) @@ -256,6 +293,20 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_ov_int8[i], num_int8) + @parameterized.expand((OVModelForCausalLM, "gpt2")) + @unittest.skipIf(not 
IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type): + model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=True) + self.assertTrue(model.stateful) + self.assertTrue(model.use_cache) + + models = [model] + + expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] + for i, model in enumerate(models): + _, num_int8, _ = get_num_quantized_nodes(model) + self.assertEqual(expected_ov_int8[i], num_int8) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type): model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=False)
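
A few usage sketches for the stateful export path follow. This first one is a minimal end-to-end example based on the new tests: export a causal LM to a stateful OpenVINO IR and run beam search, which exercises the beam_idx input and the lightweight _reorder_cache path. The tiny GPT-2 checkpoint, the output directory in the CLI comment and the generation settings are illustrative choices, not part of this change.

# CLI counterpart of stateful=True/False (assuming the standard optimum-cli entry point):
#   optimum-cli export openvino --model gpt2 gpt2_ov/                 # stateful IR (default)
#   optimum-cli export openvino --model gpt2 --no-stateful gpt2_ov/   # stateless IR
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# stateful=True is the default; it is spelled out here only for clarity.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True)
assert model.stateful and model.use_cache

inputs = tokenizer("This is a sample", return_tensors="pt")
# Beam search calls _reorder_cache, which for a stateful model only records
# next_beam_idx instead of reordering past_key_values tensors in Python.
generated = model.generate(**inputs, max_new_tokens=10, num_beams=2)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))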
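
The SDPA patching that export_pytorch now applies before tracing can also be run in isolation. A short sketch, assuming PyTorch >= 2.0; export_pytorch calls the helper automatically, so doing this by hand is only useful for inspecting the patched model.

from transformers import AutoModelForCausalLM

from optimum.exporters.openvino.better_transformer_patch import patch_model_with_bettertransformer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")  # example checkpoint
# On torch >= 2.0 this rewrites attention to torch.nn.functional.scaled_dot_product_attention,
# which the OpenVINO stateful transformations are tuned for; otherwise it warns and returns
# the model unchanged.
model = patch_model_with_bettertransformer(model)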
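
For an IR that was exported stateless (for example with an earlier release), the same transformation can be applied offline. A sketch assuming an existing export under stateless_model/ together with its config.json; it mirrors what OVModelForCausalLM._make_stateful() does when from_pretrained(..., stateful=True) receives a stateless model.

import openvino as ov
from transformers import AutoConfig

from optimum.exporters.openvino import patch_stateful, raise_if_openvino_is_too_old
from optimum.exporters.openvino.stateful import model_has_state

raise_if_openvino_is_too_old()  # stateful IRs require openvino >= 2023.3

core = ov.Core()
ov_model = core.read_model("stateless_model/openvino_model.xml")  # assumed path
config = AutoConfig.from_pretrained("stateless_model")

if not model_has_state(ov_model):
    # Adds the beam_idx parameter, fuses the cache-reorder Gather ops and hides the
    # past_key_values/present tensors inside ReadValue/Assign variables.
    patch_stateful(config, ov_model)

ov.save_model(ov_model, "stateful_model/openvino_model.xml")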
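
At the plain OpenVINO runtime level, the stateful forward() boils down to resetting the request state at the start of a sequence and feeding beam_idx next to the usual inputs. A rough sketch with an assumed IR path and input set; a real model may expose additional inputs such as position_ids.

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model("stateful_model/openvino_model.xml", "CPU")  # assumed path
request = compiled.create_infer_request()

input_ids = np.array([[101, 102, 103]], dtype=np.int64)
batch_size = input_ids.shape[0]

# First step of a new sequence: clear the kv-cache variables, as forward() does
# when past_key_values is None.
request.reset_state()

request.start_async(
    {
        "input_ids": input_ids,
        "attention_mask": np.ones_like(input_ids),
        # Identity beam_idx; beam search would replace it with the reordering
        # recorded by _reorder_cache on the previous step.
        "beam_idx": np.arange(batch_size, dtype=np.int32),
    },
    share_inputs=True,
)
request.wait()
logits = request.get_tensor("logits").data
next_token = logits[:, -1:].argmax(-1)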
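
The Gather that fuse_cache_reorder inserts on every kv-cache input has the same meaning as the np.take-based reordering that stateless models keep doing in Python between generation steps. A small numpy check of that behaviour for the common case gather_dim == batch_dim == 0:

import numpy as np

# toy kv-cache tensor: [batch (beams), heads, seq_len, head_dim]
past_key = np.arange(2 * 1 * 3 * 4, dtype=np.float32).reshape(2, 1, 3, 4)
beam_idx = np.array([1, 1], dtype=np.int32)  # both beams continue from former beam 1

reordered = np.take(past_key, beam_idx, axis=0)  # what the fused Gather computes
assert reordered.shape == past_key.shape
assert np.array_equal(reordered[0], past_key[1])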
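
For bloom the kv-cache batch dimension is batch_size * num_attention_heads, which is why the stateful _reorder_cache above expands each beam index into one entry per head before storing it as next_beam_idx. Worked through with toy numbers:

import numpy as np

num_attention_heads = 4      # would come from the model's normalized config
beam_idx = np.array([1, 1])  # both beams continue from former beam 1
batch_size = beam_idx.shape[0]

indices = np.arange(batch_size * num_attention_heads).reshape(batch_size, num_attention_heads)
next_beam_idx = np.take(indices, beam_idx, 0).flatten()
print(next_beam_idx)  # [4 5 6 7 4 5 6 7]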
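
Weight-only compression composes with stateful models, as covered by test_ovmodel_4bit_weight_compression_stateful. A condensed sketch of that flow; the checkpoint and output directory are placeholders:

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer

model = OVModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2", export=True, stateful=True
)
quantizer = OVQuantizer.from_pretrained(model, task="text-generation")
quantizer.quantize(
    save_directory="ov_int4_stateful",
    weights_only=True,
    quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}),
)

reloaded = OVModelForCausalLM.from_pretrained("ov_int4_stateful")
assert reloaded.stateful and reloaded.use_cache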