Add openvino export configs and support chatglm #454

Closed
wants to merge 3 commits
18 changes: 18 additions & 0 deletions optimum/exporters/openvino/__init__.py
@@ -1,5 +1,23 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .__main__ import main_export
from .base import init_model_configs
from .convert import export, export_models, export_pytorch_via_onnx
from .model_configs import *


init_model_configs()

__all__ = ["main_export", "export", "export_models"]
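A minimal usage sketch of the public API re-exported above (not part of the diff): it assumes main_export accepts the model_name_or_path, output, task and trust_remote_code arguments visible later in this PR; the checkpoint and output directory are example values.

from optimum.exporters.openvino import main_export

# Export a ChatGLM2 checkpoint to OpenVINO IR (openvino_model.xml plus weights in the output dir).
# "-with-past" keeps the KV-cache inputs/outputs; trust_remote_code is needed because
# chatglm ships custom modelling code on the Hub.
main_export(
    model_name_or_path="THUDM/chatglm2-6b",   # example model id
    output="chatglm2_openvino",               # example output directory
    task="text-generation-with-past",
    trust_remote_code=True,
)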
145 changes: 122 additions & 23 deletions optimum/exporters/openvino/__main__.py
@@ -15,14 +15,20 @@
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer

from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.exporters.onnx.utils import (
_get_submodels_for_export_encoder_decoder,
_get_submodels_for_export_stable_diffusion,
get_encoder_decoder_models_for_export,
get_sam_models_for_export,
get_stable_diffusion_models_for_export,
)
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

@@ -31,13 +37,115 @@
from .convert import export_models


if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel


OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_SIZE = 1e9

logger = logging.getLogger(__name__)


def _get_submodels_and_export_configs(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
task: str,
custom_onnx_configs: Dict,
custom_architecture: bool,
_variant: str,
int_dtype: str = "int64",
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
if is_stable_diffusion:
onnx_config = None
models_and_onnx_configs = get_stable_diffusion_models_for_export(
model, int_dtype=int_dtype, float_dtype=float_dtype
)
else:
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="openvino", task=task
Collaborator:

Here I would prefer we keep the ONNX config so that we don't duplicate code, as I'm unsure why we would need it for now. Is there a specific reason for that, @eaidova?

@eaidova (Collaborator, Author), Oct 18, 2023:
The main idea of this PR is to allow overriding the ONNX export configs for OpenVINO.
I see three reasons for that:

  1. Adding configurations to support new models currently requires guaranteeing that the model can be exported to ONNX (and possibly even providing a pipeline for running it with ORT). However, we no longer use the ONNX export as the default path for exporting models to OpenVINO. The set of torch operations supported by OpenVINO can differ from the set supported by the torch-to-ONNX export. For example, we successfully convert the Tapas model family to OpenVINO (e.g. https://huggingface.co/google/tapas-tiny-finetuned-wtq), but the ONNX export fails with an unsupported aten::scatter_reduce.
  2. Unblocking OpenVINO-specific optimizations of existing models. For the same reasons that we do not export models to ONNX and have our own inference pipelines, the ONNX and OpenVINO export paths may diverge over time and require different configurations (for example, merging the two decoders of seq2seq models in a way that is convenient for OpenVINO, or, in some of our other plans, exporting causal-LM models with beam search included inside the model). A simple example where the ONNX configuration is not a perfect fit for us is text-generation-with-past: if I understand correctly, the default ONNX path is a merged model, while we use the model-with-past. The ONNX configs for a model with past fill input_ids with only one token, which leads to export troubles for some models and requires model patching, with each new case handled separately (in the latest optimum there are patchers per model; previously there was just a function that checks the model type and applies the patch to the PyTorch model). In the majority of cases this could be avoided if input_ids had two tokens instead of one. These configs also use only a dynamic batch and a static sequence length (=1) for this case, which requires extra model reshaping before loading to make the sequence length dynamic for this input.
  3. Simplifying the flow for enabling new models. We really like optimum and all its features for the smooth user experience (thank you very much for everything you do) and recommend it to our OpenVINO users as the main path for running inference on HF models with OpenVINO, so we are very interested in extending the set of supported models and having the latest trending models available with OpenVINO. CLI and API support for exporting models opens the door to converting everything supported in optimum directly to OpenVINO (even if we do not have an OVModelForXXX class for it), and that is a great step for us. But right now it is not enough to just install optimum-intel from git to get the latest models available on the optimum side; it also requires installing optimum from git, and there is no guarantee that the two are synchronized. As you already highlighted with the position_ids changes, or another thing I recently found when trying to run the Mistral model: for some models it is no longer enough to specify only with_past=True to get a model with past in both inputs and outputs at the same time; an additional use_past_in_inputs flag has to be passed to the config. We would need to wait for the next official release to align, and getting new models supported that way may be inconvenient for us.

We do not duplicate code (in the majority of cases the configs mapping is filled at runtime, simply reusing the same ONNX config; only if we have an OpenVINO-specific one do we use our own); we just allow overriding some export configs where applicable.

)
onnx_config_kwargs = {}
if task.startswith("text-generation") and no_position_ids:
onnx_config_kwargs["no_position_ids"] = no_position_ids

onnx_config = onnx_config_constructor(
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
**onnx_config_kwargs,
)

onnx_config.variant = _variant
all_variants = "\n".join(
[f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

if model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("text-generation"):
model = patch_decoder_attention_mask(model)
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="openvino", task=task
)
onnx_config = onnx_config_constructor(model.config)
if onnx_config.use_past:
onnx_config.use_past_in_inputs = True
models_and_onnx_configs = {"model": (model, onnx_config)}
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}

# When specifying custom ONNX configs for supported transformers architectures, we do
# not require a custom ONNX config for each submodel.
for key, custom_onnx_config in custom_onnx_configs.items():
models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config)
else:
onnx_config = None
submodels_for_export = None
models_and_onnx_configs = {}

if fn_get_submodels is not None:
submodels_for_export = fn_get_submodels(model)
else:
if is_stable_diffusion:
submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
elif model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
submodels_for_export = _get_submodels_for_export_encoder_decoder(
model, use_past=task.endswith("-with-past")
)
elif task.startswith("text-generation"):
model = patch_decoder_attention_mask(model)
models_and_onnx_configs = {"model": model}
else:
submodels_for_export = {"model": model}

if submodels_for_export.keys() != custom_onnx_configs.keys():
logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}")
logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}")
raise ValueError(
"Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specifiy the fn_get_submodels argument, that should return a dictionary of submodules with as many items as the provided custom_onnx_configs dictionary."
)

for key, custom_onnx_config in custom_onnx_configs.items():
models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config)

# Default to the first ONNX config for stable-diffusion and custom architecture case.
if onnx_config is None:
onnx_config = next(iter(models_and_onnx_configs.values()))[1]

return onnx_config, models_and_onnx_configs


def main_export(
model_name_or_path: str,
output: Union[str, Path],
@@ -183,7 +291,7 @@ def main_export(
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
task, exporter="onnx"
task, exporter="openvino"
):
custom_architecture = True

@@ -200,7 +308,7 @@
if (
not custom_architecture
and not is_stable_diffusion
and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino")
):
if original_task == "auto": # Make -with-past the default if --task was not explicitely specified
task = task + "-with-past"
@@ -222,24 +330,15 @@
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
if not task.startswith("text-generation"):
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
)
else:
# TODO : ModelPatcher will be added in next optimum release
model = patch_decoder_attention_mask(model)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
onnx_config, models_and_onnx_configs = _get_submodels_and_export_configs(
model=model,
task=task,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
)

if int8 is None:
int8 = False
@@ -276,7 +375,7 @@ def main_export(
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

if model.config.is_encoder_decoder and task.startswith("text-generation"):
raise ValueError(
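To make the override path discussed in the review thread above concrete, here is a hedged sketch (not part of the PR) of passing an export config instance through the custom_onnx_configs argument of main_export. The config class is illustrative and simply reuses optimum's stock text-decoder inputs; the checkpoint, output directory and opset are example values.

from transformers import AutoConfig

from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.exporters.openvino import main_export
from optimum.utils import NormalizedTextConfig


class ExampleDecoderExportConfig(TextDecoderOnnxConfig):
    # Maps the generic attribute names used by dummy input generators onto this model's config.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DEFAULT_ONNX_OPSET = 13


model_id = "gpt2"  # example checkpoint with a supported architecture
config = AutoConfig.from_pretrained(model_id)

main_export(
    model_name_or_path=model_id,
    output="gpt2_openvino",
    task="text-generation-with-past",
    # One entry per submodel to override; for a plain decoder the only key is "model".
    custom_onnx_configs={
        "model": ExampleDecoderExportConfig(config, task="text-generation", use_past=True)
    },
)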
38 changes: 38 additions & 0 deletions optimum/exporters/openvino/base.py
@@ -0,0 +1,38 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Callable, Type

from optimum.exporters.tasks import TasksManager
from optimum.utils.normalized_config import NormalizedConfigManager


def init_model_configs():
supported_models = TasksManager._SUPPORTED_MODEL_TYPE
for model, export_configs in supported_models.items():
if "onnx" not in export_configs:
continue
TasksManager._SUPPORTED_MODEL_TYPE[model]["openvino"] = deepcopy(
TasksManager._SUPPORTED_MODEL_TYPE[model]["onnx"]
)


def register_normalized_config(model_type: str) -> Callable[[Type], Type]:
def decorator(config_cls: Type) -> Type:
if model_type in NormalizedConfigManager._conf:
return config_cls
NormalizedConfigManager._conf[model_type] = config_cls
return config_cls

return decorator
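A short sketch (not from the PR) of how these two helpers are meant to be used: init_model_configs mirrors every registered ONNX export config under an "openvino" exporter key, and register_normalized_config registers a NormalizedConfig for a model type that optimum does not know yet. The chatglm attribute name below is an assumption, not verified against the remote-code config.

from optimum.exporters.openvino.base import init_model_configs, register_normalized_config
from optimum.utils import NormalizedTextConfig

# Copy all "onnx" export configs to the "openvino" exporter; model_configs.py can then
# overwrite individual entries with OpenVINO-specific ones.
init_model_configs()


@register_normalized_config("chatglm")
class ChatGLMNormalizedConfig(NormalizedTextConfig):
    # Assumed attribute name on the chatglm config; the NormalizedTextConfig defaults
    # are kept for everything else. The decorator is a no-op if "chatglm" is already registered.
    NUM_LAYERS = "num_layers"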
87 changes: 87 additions & 0 deletions optimum/exporters/openvino/dummy_input_generators.py
@@ -0,0 +1,87 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Collaborator:

can be moved to optimum/intel/utils/input_generators.py

from typing import Optional, Tuple

import torch

from optimum.utils import (
DEFAULT_DUMMY_SHAPES,
DummyPastKeyValuesGenerator,
DummyTextInputGenerator,
NormalizedTextConfig,
)


class ChatGLN2DummyTextInputGenerator(DummyTextInputGenerator):
SUPPORTED_INPUT_NAMES = {
"input_ids",
"attention_mask",
"token_type_ids",
"position_ids",
}

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
input = super().generate(input_name, framework, int_dtype, float_dtype)
if input_name == "attention_mask":
input = torch.ones((input.shape[0], input.shape[1] + 1), dtype=input.dtype)
# input[0] = 0
if input_name == "position_ids":
input = torch.range(0, input.shape[1] + 1, dtype=input.dtype).repeat(1, 1)
# input[0] = 0
return input


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.multi_query_group_num = normalized_config.multi_query_group_num
self.head_dim = self.hidden_size // self.num_attention_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
past_value_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]
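To illustrate what the ChatGLM2 past-key-values generator produces, here is a small standalone sketch (not part of the PR); the hyper-parameter values are invented stand-ins, and the SimpleNamespace stands in for a real normalized config, from which the generator only reads attributes.

from types import SimpleNamespace

from optimum.exporters.openvino.dummy_input_generators import ChatGLM2DummyPastKeyValuesGenerator

# Invented sizes; head_dim = hidden_size // num_attention_heads = 16 here.
fake_normalized_config = SimpleNamespace(
    num_layers=2,
    num_attention_heads=4,
    hidden_size=64,
    multi_query_group_num=2,
)

gen = ChatGLM2DummyPastKeyValuesGenerator(
    task="text-generation-with-past",
    normalized_config=fake_normalized_config,
    batch_size=1,
    sequence_length=8,
)

past = gen.generate("past_key_values", framework="pt")
# One (key, value) pair per layer; note the sequence-first layout used by chatglm2:
# (sequence_length, batch_size, multi_query_group_num, head_dim)
print(len(past), past[0][0].shape)  # -> 2 torch.Size([8, 1, 2, 16])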