
Commit

add type hints and update doc strings
eaidova committed Dec 21, 2023
1 parent 70e2b01 commit e4b5072
Showing 4 changed files with 119 additions and 19 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/openvino/__main__.py
@@ -125,6 +125,7 @@ def main_export(
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point, `f32` - means no compression.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`Optional[bool]`):
Produce a stateful model in which all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows overriding the default shapes used during the ONNX export.
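As a hedged illustration of the `stateful` option documented above (not code from this commit; the model id, output directory, and the exact `main_export` defaults are assumptions):

```python
# Sketch: export a causal LM so that kv-cache inputs/outputs stay hidden inside the model.
# "gpt2" and the output directory are placeholders; the full main_export signature may differ.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="gpt2",
    output="ov_gpt2_stateful",
    task="text-generation-with-past",  # decoder task that uses a kv-cache
    stateful=True,                     # hide kv-cache I/O as internal model state
)
```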
13 changes: 13 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -128,6 +128,8 @@ def export(
Compression ratio between primary and backup precision (only relevant to INT4).
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows using specific shapes for the example input provided to the exporter.
stateful (`Optional[bool]`):
Produce a stateful model in which all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -238,6 +240,8 @@ def export_pytorch_via_onnx(
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`Optional[bool]`):
Produce a stateful model in which all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs
Returns:
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -296,6 +300,13 @@ def export_pytorch(
If specified, allows using specific shapes for the example input provided to the exporter.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Additional kwargs for model export.
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`Optional[bool]`):
Produce a stateful model in which all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs
Returns:
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -475,6 +486,8 @@ def export_models(
Compression ratio between primary and backup precision (only relevant to INT4).
model_kwargs (`Optional[Dict[str, Any]]`, optional):
Additional kwargs for model export.
stateful (`Optional[bool]`):
Produce a stateful model in which all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs
Raises:
ValueError: if the number of provided custom names does not match the number of models
110 changes: 95 additions & 15 deletions optimum/exporters/openvino/stateful.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging as log
from typing import List

import numpy as np
from transformers import PretrainedConfig

import openvino as ov
from openvino.runtime import opset13
@@ -22,29 +25,67 @@


def model_has_input_output_name(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input or output name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input or output
Returns:
True if an input or output with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def model_has_input(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input
Returns:
True if an input with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs], [])


def model_has_cache_reorder(ov_model):
def model_has_cache_reorder(ov_model: ov.Model):
return model_has_input(ov_model, "beam_idx")


def model_has_state(ov_model):
def model_has_state(ov_model: ov.Model):
# TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
return len(ov_model.get_sinks()) > 0


def fuse_cache_reorder(ov_model: ov.Model, not_kv_inputs, key_value_input_names, gather_dim: int):
"""Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
def fuse_cache_reorder(
ov_model: ov.Model, not_kv_inputs: List[str], key_value_input_names: List[str], gather_dim: int
):
"""
Fuses reorder_cache from the generate cycle into the ov.Model. Used with stateful models, because the model state cannot be modified directly.
Adds a new beam_idx parameter and a Gather op for each kv-cache input in a given model.
Should be run before make_stateful. Implements optimum's _reorder_cache
inside the model at the beginning of each iteration.
Gather works along the given gather_dim dimension, which may vary from model to model.
KV-cache inputs are identified based on the names in key_value_input_names.
Appends the new beam_idx parameter to not_kv_inputs.
Parameters:
ov_model (`ov.Model`):
openvino model for processing
not_kv_inputs (`List[str]`):
list of model input nodes that are not related to past key values
key_value_input_names (`List[str]`):
list of names for key value input layers
gather_dim (int):
dimension for gathering cache during the reorder pass
"""

assert not model_has_input_output_name(ov_model, "beam_idx")
@@ -63,8 +104,16 @@ def fuse_cache_reorder(ov_model: ov.Model, not_kv_inputs, key_value_input_names,
ov_model.validate_nodes_and_infer_types()
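For intuition, a plain NumPy sketch of the reordering that the fused Gather performs inside the model (shapes and names are illustrative, not taken from this commit):

```python
import numpy as np

def reorder_cache_reference(cache: np.ndarray, beam_idx: np.ndarray, gather_dim: int = 0) -> np.ndarray:
    # Select the cache slices belonging to the chosen beams along gather_dim,
    # which is what the inserted Gather node computes for each kv-cache input.
    return np.take(cache, beam_idx, axis=gather_dim)

# Example: cache shaped [batch * num_beams, num_heads, seq_len, head_dim]
cache = np.random.rand(4, 8, 16, 64).astype(np.float32)
beam_idx = np.array([0, 0, 2, 3])  # beams 0 and 1 both continue from former beam 0
assert reorder_cache_reference(cache, beam_idx).shape == cache.shape
```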


def build_state_initializer(ov_model: ov.Model, batch_dim):
"""Build initialization ShapeOf Expression for all ReadValue ops"""
def build_state_initializer(ov_model: ov.Model, batch_dim: int):
"""
Build initialization ShapeOf Expression for all ReadValue ops
Parameters:
ov_model (ov.Model):
openvino model
batch_dim (int):
index of dimension corresponding to batch size
"""
input_ids = ov_model.input("input_ids")
batch = opset13.gather(opset13.shape_of(input_ids, output_type="i64"), opset13.constant([0]), opset13.constant(0))
for op in ov_model.get_ops():
@@ -80,14 +129,32 @@ def build_state_initializer(ov_model: ov.Model, batch_dim):

def make_stateful(
ov_model: ov.Model,
not_kv_inputs,
key_value_input_names,
key_value_output_names,
batch_dim,
num_attention_heads,
num_beams_and_batch=None,
not_kv_inputs: List[str],
key_value_input_names: List[str],
key_value_output_names: List[str],
batch_dim: int,
num_attention_heads: int,
num_beams_and_batch: int = None,
):
"""Hides kv-cache inputs and outputs inside the model as variables."""
"""
Hides kv-cache inputs and outputs inside the model as variables.
Parameters:
ov_model (ov.Model):
openvino model
not_kv_inputs (`List[str]`):
list of model input nodes that are not related to past key values
key_value_input_names (`List[str]`):
list of names for key value input layers
key_value_output_names (`List[str]`):
list of names for key value output layers
batch_dim (int):
index of the batch dimension in key value layers
num_attention_heads (int):
number of attention heads, used for batch dimension initialization
num_beams_and_batch (int):
precalculated number of beams and batch for shapes initialization
"""
from openvino._offline_transformations import apply_make_stateful_transformation

input_output_map = {}
Expand All @@ -101,7 +168,7 @@ def make_stateful(
shape[0] = num_beams_and_batch
input.get_node().set_partial_shape(shape)
else:
print(f"[ WARNING ] Rank of {input.get_any_name()} input of the model is not 2, batch size is not set")
log.warning(f"Rank of {input.get_any_name()} input of the model is not 2, batch size is not set")

for kv_name_pair in zip(key_value_input_names, key_value_output_names):
input_output_map[kv_name_pair[0]] = kv_name_pair[1]
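The loop above pairs each kv-cache input name with its output name in `input_output_map`; presumably that map is then handed to OpenVINO's make-stateful transformation (the actual call is outside the lines shown here, and the model path and tensor names below are illustrative assumptions):

```python
# Hypothetical sketch of the final step, not code from this hunk.
import openvino as ov
from openvino._offline_transformations import apply_make_stateful_transformation

ov_model = ov.Core().read_model("openvino_model.xml")  # a decoder exported with explicit kv-cache I/O
input_output_map = {
    "past_key_values.0.key": "present.0.key",      # illustrative tensor names
    "past_key_values.0.value": "present.0.value",
}
apply_make_stateful_transformation(ov_model, input_output_map)  # each pair becomes an internal variable
```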
@@ -121,13 +188,26 @@ def make_stateful(


def raise_if_openvino_is_too_old():
"""
Check the openvino version and raise an error if it does not support stateful models
"""
if is_openvino_version("<", "2023.3"):
raise ValueError(
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0."
)


def patch_stateful(config, ov_model):
def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
"""
Apply the make-stateful transformation to the model to hide key value inputs and outputs inside the model.
Selects transformation parameters based on the model architecture.
Parameters:
config (`PretrainedConfig`):
model pretrained config
ov_model (`ov.Model`):
openvino model
"""
raise_if_openvino_is_too_old()

key_value_input_names = [
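A hedged end-to-end sketch of applying `patch_stateful` to an already exported decoder (the real call site lives in the export pipeline and may differ; the model id and file names are placeholders):

```python
import openvino as ov
from transformers import AutoConfig
from optimum.exporters.openvino.stateful import patch_stateful

config = AutoConfig.from_pretrained("gpt2")            # placeholder model id
ov_model = ov.Core().read_model("openvino_model.xml")  # decoder with explicit kv-cache inputs/outputs
patch_stateful(config, ov_model)                       # fuse cache reorder + make stateful
ov.save_model(ov_model, "openvino_model_stateful.xml")
```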
14 changes: 10 additions & 4 deletions tests/openvino/test_modeling.py
@@ -73,14 +73,14 @@
from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME
from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
from optimum.intel.openvino.modeling_timm import TimmImageProcessor
from optimum.intel.utils.import_utils import is_openvino_version
from optimum.utils import (
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
DIFFUSION_MODEL_UNET_SUBFOLDER,
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
)
from optimum.utils.testing_utils import require_diffusers
from optimum.intel.utils.import_utils import is_openvino_version


TENSOR_ALIAS_TO_TYPE = {
@@ -488,7 +488,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"pegasus",
)
GENERATION_LENGTH = 100
IS_SUPPORT_STATEFUL = is_openvino_version(">=" , "2023.3")
IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -646,14 +646,20 @@ def test_stateful(self, model_arch):
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
position_ids = position_ids[:, -1:] + 1
pkv = ov_outputs.past_key_values
ov_outputs = ov_model(input_ids=next_token, position_ids=position_ids, attention_mask=attention_mask, past_key_values=pkv)
ov_outputs = ov_model(
input_ids=next_token, position_ids=position_ids, attention_mask=attention_mask, past_key_values=pkv
)
self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
with torch.no_grad():
transformers_outputs = transformers_model(input_ids=next_token, attention_mask=attention_mask, past_key_values=transformers_outputs.past_key_values)
transformers_outputs = transformers_model(
input_ids=next_token,
attention_mask=attention_mask,
past_key_values=transformers_outputs.past_key_values,
)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))

del transformers_model
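For context on what `test_stateful` exercises, a hedged sketch of the user-facing flow; the `stateful` keyword in `from_pretrained` is an assumption, and the placeholder `past_key_values` mirrors what the test asserts for stateful models:

```python
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True)  # stateful kwarg assumed

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model(**inputs)
# With a stateful model the kv-cache is kept inside the OpenVINO model, so the returned
# past_key_values is only a placeholder tuple (len == 1 with an empty inner tuple, as asserted above).
print(outputs.logits.shape, len(outputs.past_key_values))
```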
