From e57baaca2d0f566c8daa9fa027da07e58fe11436 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 16 Oct 2023 19:31:34 +0400
Subject: [PATCH] Add openvino export configs and support chatglm

---
 optimum/exporters/openvino/__init__.py        |   4 +
 optimum/exporters/openvino/__main__.py        | 143 +++++++++++++++---
 optimum/exporters/openvino/base.py            |  25 +++
 .../openvino/dummy_input_generators.py        |  61 ++++++++
 optimum/exporters/openvino/model_configs.py   |  91 +++++++++++
 .../exporters/openvino/normalized_configs.py  |   9 ++
 optimum/intel/openvino/modeling_decoder.py    |  31 +++-
 7 files changed, 337 insertions(+), 27 deletions(-)
 create mode 100644 optimum/exporters/openvino/base.py
 create mode 100644 optimum/exporters/openvino/dummy_input_generators.py
 create mode 100644 optimum/exporters/openvino/model_configs.py
 create mode 100644 optimum/exporters/openvino/normalized_configs.py

diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py
index d87d8dda9e..f21ca7e595 100644
--- a/optimum/exporters/openvino/__init__.py
+++ b/optimum/exporters/openvino/__init__.py
@@ -1,5 +1,9 @@
 from .__main__ import main_export
+from .base import init_model_configs
 from .convert import export, export_models, export_pytorch_via_onnx
+from .model_configs import *
 
+init_model_configs()
+
 __all__ = ["main_export", "export", "export_models"]
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 782aa0bc0d..52c72944e8 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -15,14 +15,20 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from transformers import AutoTokenizer
 
 from optimum.exporters import TasksManager
-from optimum.exporters.onnx import __main__ as optimum_main
 from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
+from optimum.exporters.onnx.utils import (
+    _get_submodels_for_export_encoder_decoder,
+    _get_submodels_for_export_stable_diffusion,
+    get_encoder_decoder_models_for_export,
+    get_sam_models_for_export,
+    get_stable_diffusion_models_for_export,
+)
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
@@ -31,6 +37,10 @@
 from .convert import export_models
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, TFPreTrainedModel
+
+
 OV_XML_FILE_NAME = "openvino_model.xml"
 
 _MAX_UNCOMPRESSED_SIZE = 1e9
@@ -38,6 +48,102 @@
 logger = logging.getLogger(__name__)
 
 
+def _get_submodels_and_export_configs(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    task: str,
+    custom_onnx_configs: Dict,
+    custom_architecture: bool,
+    _variant: str,
+    int_dtype: str = "int64",
+    float_dtype: str = "fp32",
+    fn_get_submodels: Optional[Callable] = None,
+    preprocessors: Optional[List[Any]] = None,
+    no_position_ids: bool = False,
+):
+    is_stable_diffusion = "stable-diffusion" in task
+    if not custom_architecture:
+        if is_stable_diffusion:
+            onnx_config = None
+            models_and_onnx_configs = get_stable_diffusion_models_for_export(
+                model, int_dtype=int_dtype, float_dtype=float_dtype
+            )
+        else:
+            onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+                model=model, exporter="openvino", task=task
+            )
+            onnx_config_kwargs = {}
+            if task.startswith("text-generation") and no_position_ids:
+                onnx_config_kwargs["no_position_ids"] = no_position_ids
+
+            onnx_config = onnx_config_constructor(
+                model.config,
+                int_dtype=int_dtype,
+                float_dtype=float_dtype,
+                preprocessors=preprocessors,
+                **onnx_config_kwargs,
+            )
+
+            onnx_config.variant = _variant
+            all_variants = "\n".join(
+                [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
+            )
+            logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")
+
+            if model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
+                models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
+            elif task.startswith("text-generation"):
+                model = patch_decoder_attention_mask(model)
+                onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+                    model=model, exporter="openvino", task=task
+                )
+                onnx_config = onnx_config_constructor(model.config)
+                models_and_onnx_configs = {"model": (model, onnx_config)}
+            elif model.config.model_type == "sam":
+                models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
+            else:
+                models_and_onnx_configs = {"model": (model, onnx_config)}
+
+        # When specifying custom ONNX configs for supported transformers architectures, we do
+        # not force to specify a custom ONNX config for each submodel.
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config)
+    else:
+        onnx_config = None
+        submodels_for_export = None
+        models_and_onnx_configs = {}
+
+        if fn_get_submodels is not None:
+            submodels_for_export = fn_get_submodels(model)
+        else:
+            if is_stable_diffusion:
+                submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
+            elif model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
+                submodels_for_export = _get_submodels_for_export_encoder_decoder(
+                    model, use_past=task.endswith("-with-past")
+                )
+            elif task.startswith("text-generation"):
+                model = patch_decoder_attention_mask(model)
+                submodels_for_export = {"model": model}
+            else:
+                submodels_for_export = {"model": model}
+
+        if submodels_for_export.keys() != custom_onnx_configs.keys():
+            logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}")
+            logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}")
+            raise ValueError(
+                "Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specify the fn_get_submodels argument, that should return a dictionary of submodules with as many items as the provided custom_onnx_configs dictionary."
+            )
+
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config)
+
+    # Default to the first ONNX config for stable-diffusion and custom architecture case.
+    if onnx_config is None:
+        onnx_config = next(iter(models_and_onnx_configs.values()))[1]
+
+    return onnx_config, models_and_onnx_configs
+
+
 def main_export(
     model_name_or_path: str,
     output: Union[str, Path],
@@ -183,7 +289,7 @@ def main_export(
             f"If you want to support {model_type} please propose a PR or open up an issue."
         )
 
     if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
-        task, exporter="onnx"
+        task, exporter="openvino"
     ):
         custom_architecture = True
@@ -200,7 +306,7 @@ def main_export(
     if (
         not custom_architecture
         and not is_stable_diffusion
-        and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
+        and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino")
     ):
         if original_task == "auto":  # Make -with-past the default if --task was not explicitely specified
             task = task + "-with-past"
@@ -222,24 +328,15 @@ def main_export(
     preprocessors = maybe_load_preprocessors(
         model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
     )
-    if not task.startswith("text-generation"):
-        onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
-            model=model,
-            task=task,
-            monolith=False,
-            custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
-            custom_architecture=custom_architecture,
-            fn_get_submodels=fn_get_submodels,
-            preprocessors=preprocessors,
-            _variant="default",
-        )
-    else:
-        # TODO : ModelPatcher will be added in next optimum release
-        model = patch_decoder_attention_mask(model)
-
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config)
-        models_and_onnx_configs = {"model": (model, onnx_config)}
+    onnx_config, models_and_onnx_configs = _get_submodels_and_export_configs(
+        model=model,
+        task=task,
+        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+        custom_architecture=custom_architecture,
+        fn_get_submodels=fn_get_submodels,
+        preprocessors=preprocessors,
+        _variant="default",
+    )
 
     if int8 is None:
         int8 = False
@@ -276,7 +373,7 @@ def main_export(
     generation_config = getattr(model, "generation_config", None)
     if generation_config is not None:
         generation_config.save_pretrained(output)
-    maybe_save_preprocessors(model_name_or_path, output)
+    maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
 
     if model.config.is_encoder_decoder and task.startswith("text-generation"):
         raise ValueError(
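With this change, text-generation models go through the same _get_submodels_and_export_configs path as every other architecture, and configs are resolved against the dedicated "openvino" backend registry. A minimal sketch of how the exporter entry point might be driven for a chatglm checkpoint; the model id and output directory below are placeholders, not part of this patch:

from optimum.exporters.openvino import main_export

# Hypothetical checkpoint id and output folder, shown only for illustration.
main_export(
    model_name_or_path="THUDM/chatglm2-6b",
    output="chatglm2_ov",
    task="text-generation-with-past",
    trust_remote_code=True,
)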
diff --git a/optimum/exporters/openvino/base.py b/optimum/exporters/openvino/base.py
new file mode 100644
index 0000000000..2de28432c8
--- /dev/null
+++ b/optimum/exporters/openvino/base.py
@@ -0,0 +1,25 @@
+from copy import deepcopy
+from typing import Callable, Type
+
+from optimum.exporters.tasks import TasksManager
+from optimum.utils.normalized_config import NormalizedConfigManager
+
+
+def init_model_configs():
+    supported_models = TasksManager._SUPPORTED_MODEL_TYPE
+    for model, export_configs in supported_models.items():
+        if "onnx" not in export_configs:
+            continue
+        TasksManager._SUPPORTED_MODEL_TYPE[model]["openvino"] = deepcopy(
+            TasksManager._SUPPORTED_MODEL_TYPE[model]["onnx"]
+        )
+
+
+def register_normalized_config(model_type: str) -> Callable[[Type], Type]:
+    def decorator(config_cls: Type) -> Type:
+        if model_type in NormalizedConfigManager._conf:
+            return config_cls
+        NormalizedConfigManager._conf[model_type] = config_cls
+        return config_cls
+
+    return decorator
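init_model_configs() clones every existing "onnx" entry of TasksManager._SUPPORTED_MODEL_TYPE into a parallel "openvino" backend, so lookups with exporter="openvino" keep working for all architectures while chatglm-specific configs can be layered on top. A small sketch of the expected effect, assuming a model type such as llama that already has ONNX configs registered in the installed optimum version:

from optimum.exporters.tasks import TasksManager

import optimum.exporters.openvino  # importing the package runs init_model_configs()

# The "openvino" backend now mirrors the "onnx" one for pre-existing model types.
onnx_tasks = set(TasksManager.get_supported_tasks_for_model_type("llama", "onnx"))
ov_tasks = set(TasksManager.get_supported_tasks_for_model_type("llama", "openvino"))
assert onnx_tasks == ov_tasks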
diff --git a/optimum/exporters/openvino/dummy_input_generators.py b/optimum/exporters/openvino/dummy_input_generators.py
new file mode 100644
index 0000000000..219b7193cf
--- /dev/null
+++ b/optimum/exporters/openvino/dummy_input_generators.py
@@ -0,0 +1,61 @@
+from typing import Optional, Tuple
+
+from optimum.utils import (
+    DEFAULT_DUMMY_SHAPES,
+    DummyPastKeyValuesGenerator,
+    DummyTextInputGenerator,
+    NormalizedTextConfig,
+)
+
+
+class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = {
+        "input_ids",
+        "attention_mask",
+        "token_type_ids",
+        "position_ids",
+    }
+
+
+class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.multi_query_group_num = normalized_config.multi_query_group_num
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        past_value_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
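ChatGLM2 keeps its KV cache in (sequence_length, batch_size, multi_query_group_num, head_dim) order rather than the usual (batch_size, num_heads, sequence_length, head_dim), which is why a dedicated past-key-values generator is needed. A quick sketch of the shapes it produces, using made-up config values purely for illustration:

from transformers import PretrainedConfig

from optimum.exporters.openvino.dummy_input_generators import ChatGLM2DummyPastKeyValuesGenerator
from optimum.exporters.openvino.normalized_configs import ChatGLM2NormalizedConfig

# Toy values; a real chatglm2 config.json defines these fields.
config = PretrainedConfig(
    num_layers=2, hidden_size=64, num_attention_heads=8, multi_query_group_num=2, padded_vocab_size=65024
)
generator = ChatGLM2DummyPastKeyValuesGenerator(
    task="text-generation", normalized_config=ChatGLM2NormalizedConfig(config), batch_size=1, sequence_length=4
)
past_key_values = generator.generate("past_key_values", framework="pt")
# One (key, value) pair per layer, each shaped (sequence_length, batch_size, multi_query_group_num, head_dim).
print(len(past_key_values), past_key_values[0][0].shape)  # 2 torch.Size([4, 1, 2, 8])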
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
new file mode 100644
index 0000000000..eeec30d75e
--- /dev/null
+++ b/optimum/exporters/openvino/model_configs.py
@@ -0,0 +1,91 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, Type
+
+from optimum.exporters.onnx import TextDecoderOnnxConfig
+from optimum.exporters.tasks import TasksManager, make_backend_config_constructor_for_task
+
+from .dummy_input_generators import ChatGLM2DummyPastKeyValuesGenerator, ChatGLM2DummyTextInputGenerator
+from .normalized_configs import ChatGLM2NormalizedConfig
+
+
+def create_register(overwrite_existing: bool = False):
+    def wrapper(model_type: str, *supported_tasks: str) -> Callable[[Type], Type]:
+        def decorator(config_cls: Type) -> Type:
+            mapping = TasksManager._SUPPORTED_MODEL_TYPE.get(model_type, {})
+            mapping_backend = mapping.get("openvino", {})
+            for task in supported_tasks:
+                normalized_task = task
+                if "-with-past" in task:
+                    normalized_task = task.split("-with-past")[0]
+                if normalized_task not in TasksManager.get_all_tasks():
+                    known_tasks = ", ".join(TasksManager.get_all_tasks())
+                    raise ValueError(
+                        f'The TasksManager does not know the task called "{task}", known tasks: {known_tasks}.'
+                    )
+                if not overwrite_existing and task in mapping_backend:
+                    continue
+                mapping_backend[task] = make_backend_config_constructor_for_task(config_cls, task)
+            mapping["openvino"] = mapping_backend
+            TasksManager._SUPPORTED_MODEL_TYPE[model_type] = mapping
+            return config_cls
+
+        return decorator
+
+    return wrapper
+
+
+register_in_tasks_manager = create_register(True)
+
+
+@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = ChatGLM2NormalizedConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
+    no_position_ids = False
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = super().inputs
+        common_inputs.pop("attention_mask")
+        if not self.no_position_ids and self.task == "text-generation":
+            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
+
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
+                output mapping, this is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
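The config registered here drops attention_mask from the exported inputs, adds position_ids, and lays out each past key/value with axis 0 as the past sequence length and axis 1 as the batch, matching the cache layout above. A rough sketch of what the inputs mapping resolves to, again with a toy config (illustrative values only, and subject to the installed optimum version's OnnxConfig constructor):

from transformers import PretrainedConfig

from optimum.exporters.openvino.model_configs import ChatGLM2OpenVINOConfig

# Illustrative values; real checkpoints provide them in config.json.
config = PretrainedConfig(
    num_layers=2, hidden_size=64, num_attention_heads=8, multi_query_group_num=2, padded_vocab_size=65024
)
export_config = ChatGLM2OpenVINOConfig(config, task="text-generation")
print(export_config.inputs)
# Expected: {'input_ids': {0: 'batch_size', 1: 'sequence_length'},
#            'position_ids': {0: 'batch_size', 1: 'sequence_length'}}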
diff --git a/optimum/exporters/openvino/normalized_configs.py b/optimum/exporters/openvino/normalized_configs.py
new file mode 100644
index 0000000000..c50cf11741
--- /dev/null
+++ b/optimum/exporters/openvino/normalized_configs.py
@@ -0,0 +1,9 @@
+from optimum.utils import NormalizedTextConfig
+
+from .base import register_normalized_config
+
+
+@register_normalized_config("chatglm")
+class ChatGLM2NormalizedConfig(NormalizedTextConfig):
+    NUM_LAYERS = "num_layers"
+    VOCAB_SIZE = "padded_vocab_size"
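register_normalized_config("chatglm") adds this mapping to NormalizedConfigManager, so generic code can read the layer count and vocabulary size through the usual attribute names even though chatglm stores them as num_layers and padded_vocab_size. A short sketch of the lookup, with illustrative values:

from transformers import PretrainedConfig

from optimum.utils.normalized_config import NormalizedConfigManager

import optimum.exporters.openvino  # importing the package registers the "chatglm" entry

normalized_cls = NormalizedConfigManager.get_normalized_config_class("chatglm")
# Illustrative values; a real chatglm2 config.json carries these fields.
normalized = normalized_cls(PretrainedConfig(num_layers=28, padded_vocab_size=65024))
print(normalized.num_layers, normalized.vocab_size)  # 28 65024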
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 68d737fe74..9e3262ac92 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -35,6 +35,10 @@
 from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
 
+if TYPE_CHECKING:
+    pass
+
+
 if is_transformers_version("<", "4.25.0"):
     from transformers.generation_utils import GenerationMixin
 else:
@@ -269,7 +273,9 @@ def _reshape(
             shapes[inputs][0] = -1
             input_name = inputs.get_any_name()
             if input_name.startswith("past_key_values"):
-                if len(inputs.partial_shape) == 3 and input_name.endswith("value"):
+                if (
+                    len(inputs.partial_shape) == 3 and input_name.endswith("value")
+                ) or self.config.model_type == "chatglm":
                     shapes[inputs][1] = -1
                 else:
                     shapes[inputs][2] = -1
@@ -312,6 +318,7 @@ def forward(
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         self.compile()
@@ -345,6 +352,11 @@ def forward(
             for input_name in self.key_value_input_names:
                 model_inputs = self.model.input(input_name)
                 shape = model_inputs.get_partial_shape()
+                if self.config.model_type == "chatglm":
+                    shape[0] = 0
+                    shape[1] = shape_input_ids[0] * num_attention_heads
+                    inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+                    continue
                 shape[0] = shape_input_ids[0] * num_attention_heads
                 if shape[2].is_dynamic:
                     shape[2] = 0
@@ -358,6 +370,8 @@ def forward(
         if "attention_mask" in self.input_names and attention_mask is not None:
             inputs["attention_mask"] = np.array(attention_mask)
+        if "position_ids" in self.input_names and position_ids is not None:
+            inputs["position_ids"] = position_ids
 
         # Run inference
         self.request.start_async(inputs, shared_memory=True)
         self.request.wait()
@@ -385,12 +399,21 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
 
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": self.use_cache,
-            "position_ids": None,
-            "attention_mask": kwargs.get("attention_mask", None),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
             "token_type_ids": None,
         }
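On the modeling side, the runtime now feeds position_ids (derived from the attention mask when callers do not pass them) and sizes the empty chatglm cache with the sequence dimension first. A usage sketch of the end-to-end flow; the checkpoint id is a placeholder, and loading it requires remote code:

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "THUDM/chatglm2-6b"  # placeholder id; any chatglm2-style checkpoint follows the same path
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("What is OpenVINO?", return_tensors="pt")
# position_ids are built from the attention mask inside prepare_inputs_for_generation,
# so padded, batched prompts keep generating correctly.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))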