Fix openvino main export #439

Merged · 6 commits · Sep 29, 2023
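This PR replaces the per-class, hand-rolled ONNX-config export paths with a single call into `main_export`, and moves the OVModel classes and tests to optimum's current task names (`text-generation`, `text2text-generation`, `text-classification`, `fill-mask`, ...). A minimal sketch of the entry point the diff below converges on; the checkpoint name and output directory are illustrative, not taken from the PR:

    from optimum.exporters.openvino import main_export

    # Export a decoder-only checkpoint to OpenVINO IR, keeping the KV-cache
    # inputs/outputs by using the "-with-past" task variant.
    main_export(
        model_name_or_path="gpt2",
        output="ov_gpt2",
        task="text-generation-with-past",
    )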
30 changes: 20 additions & 10 deletions optimum/exporters/openvino/__main__.py
@@ -27,6 +27,7 @@
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_save_preprocessors
 
+from ...intel.utils.modeling_utils import patch_decoder_attention_mask
 from .convert import export_models
 
 
@@ -213,15 +214,24 @@ def main_export(
         else:
             possible_synonyms = ""
         logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
-    onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
-        model=model,
-        task=task,
-        monolith=False,
-        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
-        custom_architecture=custom_architecture,
-        fn_get_submodels=fn_get_submodels,
-        _variant="default",
-    )
+
+    if not task.startswith("text-generation"):
+        onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
+            model=model,
+            task=task,
+            monolith=False,
+            custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+            custom_architecture=custom_architecture,
+            fn_get_submodels=fn_get_submodels,
+            _variant="default",
+        )
+    else:
+        # TODO : ModelPatcher will be added in next optimum release
+        model = patch_decoder_attention_mask(model)
+
+        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+        onnx_config = onnx_config_constructor(model.config)
+        models_and_onnx_configs = {"model": (model, onnx_config)}
 
     if not is_stable_diffusion:
         needs_pad_token_id = (
@@ -254,7 +264,7 @@ def main_export(
                 f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
             )
 
-        files_subpaths = None
+        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()]
     else:
         # save the subcomponent configuration
        for model_name in models_and_onnx_configs:
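With `files_subpaths` now derived from the submodel keys, every non-diffusion export gets a predictable `openvino_*.xml` file name. A small sketch of the naming rule; the dictionary keys are the usual ones for a seq2seq export and stand in for real (model, config) pairs:

    # Keys mirror models_and_onnx_configs; values elided for illustration.
    submodels = {"encoder_model": None, "decoder_model": None, "decoder_with_past_model": None}
    files_subpaths = ["openvino_" + name + ".xml" for name in submodels]
    # -> ['openvino_encoder_model.xml', 'openvino_decoder_model.xml', 'openvino_decoder_with_past_model.xml']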
6 changes: 1 addition & 5 deletions optimum/exporters/openvino/convert.py
@@ -137,11 +137,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig
     onnx_path = Path(output).with_suffix(".onnx")
     input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
     ov_model = convert_model(str(onnx_path))
-    save_model(
-        ov_model,
-        output.parent / output,
-        compress_to_fp16=False,
-    )
+    save_model(ov_model, output.parent / output, compress_to_fp16=False)
     return input_names, output_names, True
 
 
6 changes: 1 addition & 5 deletions optimum/intel/openvino/modeling.py
@@ -549,11 +549,7 @@ def from_pretrained(
             model = TimmForImageClassification.from_pretrained(model_id, **kwargs)
             onnx_config = TimmOnnxConfig(model.config)
 
-            return cls._to_load(
-                model=model,
-                config=config,
-                onnx_config=onnx_config,
-            )
+            return cls._to_load(model=model, config=config, onnx_config=onnx_config)
         else:
             return super().from_pretrained(
                 model_id=model_id,
46 changes: 14 additions & 32 deletions optimum/intel/openvino/modeling_base.py
@@ -26,10 +26,9 @@
 from transformers.file_utils import add_start_docstrings
 
 from optimum.exporters.onnx import OnnxConfig
-from optimum.exporters.tasks import TasksManager
 from optimum.modeling_base import OptimizedModel
 
-from ...exporters.openvino import export
+from ...exporters.openvino import export, main_export
 from ..utils.import_utils import is_transformers_version
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME
 
@@ -240,42 +239,25 @@ def _from_transformers(
             kwargs (`Dict`, *optional*):
                 kwargs will be passed to the model during initialization
         """
-        task = task or cls.export_feature
-
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "trust_remote_code": trust_remote_code,
-        }
-
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        model_type = model.config.model_type.replace("_", "-")
-
-        onnx_config_class = TasksManager.get_exporter_config_constructor(
-            exporter="onnx",
-            model=model,
-            task=task,
-            model_name=model_id,
-            model_type=model_type,
-        )
-
-        onnx_config = onnx_config_class(model.config)
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
 
-        return cls._to_load(
-            model=model,
-            config=config,
-            onnx_config=onnx_config,
-            use_auth_token=use_auth_token,
+        main_export(
+            model_name_or_path=model_id,
+            output=save_dir_path,
+            task=task or cls.export_feature,
+            subfolder=subfolder,
             revision=revision,
-            force_download=force_download,
             cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
             force_download=force_download,
+            trust_remote_code=trust_remote_code,
         )
 
+        config.save_pretrained(save_dir_path)
+        return cls._from_pretrained(model_id=save_dir_path, config=config, **kwargs)
 
     @classmethod
     def _to_load(
         cls,
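After this change, `export=True` on any OVModel delegates conversion to `main_export` and reloads the IR via `_from_pretrained`. A usage sketch, assuming a standard checkpoint (the model name is illustrative):

    from optimum.intel import OVModelForSequenceClassification

    # _from_transformers runs main_export into a TemporaryDirectory, saves the
    # config there, then loads the converted openvino_model.xml back.
    model = OVModelForSequenceClassification.from_pretrained("bert-base-uncased", export=True)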
62 changes: 17 additions & 45 deletions optimum/intel/openvino/modeling_base_seq2seq.py
@@ -24,10 +24,7 @@
 from transformers import PretrainedConfig
 from transformers.file_utils import add_start_docstrings
 
-from optimum.exporters import TasksManager
-from optimum.exporters.onnx import get_encoder_decoder_models_for_export
-
-from ...exporters.openvino import export_models
+from ...exporters.openvino import main_export
 from ..utils.import_utils import is_transformers_version
 from .modeling_base import OVBaseModel
 from .utils import (
@@ -244,56 +241,31 @@ def _from_transformers(
             kwargs (`Dict`, *optional*):
                 kwargs will be passed to the model during initialization
         """
-        task = task or cls.export_feature
-
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
 
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "trust_remote_code": trust_remote_code,
-        }
-
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config, use_past=use_cache)
-        models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
-        encoder_file_name = os.path.join("encoder", OV_ENCODER_NAME)
-        decoder_file_name = os.path.join("decoder", OV_DECODER_NAME)
-        decoder_with_past_file_name = os.path.join("decoder_with_past", OV_DECODER_WITH_PAST_NAME)
-
-        output_names = [encoder_file_name, decoder_file_name]
-        if use_cache is True:
-            output_names.append(decoder_with_past_file_name)
-
-        export_models(
-            models_and_onnx_configs=models_and_onnx_configs,
-            opset=onnx_config.DEFAULT_ONNX_OPSET,
-            output_dir=save_dir_path,
-            output_names=output_names,
-        )
+        if task is None:
+            task = cls.export_feature
 
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            use_cache=use_cache,
-            from_onnx=False,
-            use_auth_token=use_auth_token,
+        if use_cache:
+            task = task + "-with-past"
+
+        main_export(
+            model_name_or_path=model_id,
+            output=save_dir_path,
+            task=task,
+            subfolder=subfolder,
             revision=revision,
-            force_download=force_download,
             cache_dir=cache_dir,
-            encoder_file_name=encoder_file_name,
-            decoder_file_name=decoder_file_name,
-            decoder_with_past_file_name=decoder_with_past_file_name,
+            use_auth_token=use_auth_token,
             local_files_only=local_files_only,
-            **kwargs,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
         )
 
+        config.save_pretrained(save_dir_path)
+        return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)
 
     def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
         shapes = {}
         for inputs in model.inputs:
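The seq2seq path no longer assembles encoder/decoder ONNX configs and output names by hand; it derives one task string (appending "-with-past" when `use_cache` is set) and lets `main_export` emit the encoder, decoder, and decoder-with-past IRs. Usage sketch (checkpoint name illustrative):

    from optimum.intel import OVModelForSeq2SeqLM

    # One main_export call produces all three submodel IRs.
    model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True, use_cache=True)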
49 changes: 18 additions & 31 deletions optimum/intel/openvino/modeling_decoder.py
@@ -27,12 +27,10 @@
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from optimum.exporters import TasksManager
 from optimum.utils import NormalizedConfigManager
 
-from ...exporters.openvino import export
+from ...exporters.openvino import main_export
 from ..utils.import_utils import is_transformers_version
-from ..utils.modeling_utils import patch_decoder_attention_mask
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
 from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
@@ -219,44 +217,33 @@ def _from_transformers(
                 f"This architecture : {config.model_type} was not validated, only :{', '.join(_SUPPORTED_ARCHITECTURES)} architectures were "
                 "validated, use at your own risk."
             )
-        task = task or cls.export_feature
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "trust_remote_code": trust_remote_code,
-        }
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config, use_past=use_cache)
 
-        # TODO : create ModelPatcher to patch each architecture
-        model = patch_decoder_attention_mask(model)
+        if task is None:
+            task = cls.export_feature
 
-        # Export the model to the OpenVINO IR format
-        export(model=model, config=onnx_config, output=save_dir_path / OV_XML_FILE_NAME)
+        if use_cache:
+            task = task + "-with-past"
 
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            from_onnx=False,
-            use_auth_token=use_auth_token,
+        main_export(
+            model_name_or_path=model_id,
+            output=save_dir_path,
+            task=task,
+            subfolder=subfolder,
             revision=revision,
-            force_download=force_download,
             cache_dir=cache_dir,
-            file_name=OV_XML_FILE_NAME,
+            use_auth_token=use_auth_token,
             local_files_only=local_files_only,
-            use_cache=use_cache,
-            **kwargs,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
         )
 
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        config.save_pretrained(save_dir_path)
+        return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)
 
     def _reshape(
         self,
         model: openvino.runtime.Model,
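The decoder path follows the same pattern, and the attention-mask patching it used to do inline now lives in the text-generation branch of `main_export` (see `__main__.py` above). Usage sketch (checkpoint name illustrative):

    from optimum.intel import OVModelForCausalLM

    # Resolves to task="text-generation-with-past" and routes through main_export.
    model = OVModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)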
8 changes: 2 additions & 6 deletions optimum/intel/openvino/modeling_timm.py
@@ -94,13 +94,9 @@ def from_pretrained(cls, model_name_or_path, **kwargs):
         return cls(config, **kwargs)
 
     def forward(self, pixel_values: Optional[torch.Tensor] = None):
-        logits = self.model(
-            pixel_values,
-        )
+        logits = self.model(pixel_values)
 
-        return ImageClassifierOutput(
-            logits=logits,
-        )
+        return ImageClassifierOutput(logits=logits)
 
 
 # Adapted from ViTImageProcessor - https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/image_processing_vit.py
2 changes: 2 additions & 0 deletions optimum/intel/openvino/utils.py
@@ -78,6 +78,7 @@
 
 
 _HEAD_TO_AUTOMODELS = {
+    "feature-extraction": "OVModelForFeatureExtraction",
     "fill-mask": "OVModelForMaskedLM",
     "text-generation": "OVModelForCausalLM",
     "text2text-generation": "OVModelForSeq2SeqLM",
@@ -87,6 +88,7 @@
     "image-classification": "OVModelForImageClassification",
     "audio-classification": "OVModelForAudioClassification",
     "stable-diffusion": "OVStableDiffusionPipeline",
+    "stable-diffusion-xl": "OVStableDiffusionXLPipeline",
 }
 
 
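`_HEAD_TO_AUTOMODELS` maps a task name to the class expected to load the exported model; the updated tests resolve it after stripping the `-with-past` suffix. A sketch of the lookup, mirroring the test code below:

    from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS

    task = "text-generation-with-past"
    _HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]  # -> 'OVModelForCausalLM'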
38 changes: 28 additions & 10 deletions tests/openvino/test_exporters_cli.py
@@ -19,6 +19,20 @@
 from utils_tests import MODEL_NAMES
 
 from optimum.exporters.openvino.__main__ import main_export
+from optimum.intel import (  # noqa
+    OVModelForAudioClassification,
+    OVModelForCausalLM,
+    OVModelForFeatureExtraction,
+    OVModelForImageClassification,
+    OVModelForMaskedLM,
+    OVModelForQuestionAnswering,
+    OVModelForSeq2SeqLM,
+    OVModelForSequenceClassification,
+    OVModelForTokenClassification,
+    OVStableDiffusionPipeline,
+    OVStableDiffusionXLPipeline,
+)
+from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS
 
 
 class OVCLIExportTestCase(unittest.TestCase):
@@ -27,15 +41,17 @@ class OVCLIExportTestCase(unittest.TestCase):
     """
 
     SUPPORTED_ARCHITECTURES = (
-        ["causal-lm", "gpt2"],
-        ["causal-lm-with-past", "gpt2"],
-        ["seq2seq-lm", "t5"],
-        ["seq2seq-lm-with-past", "t5"],
-        ["sequence-classification", "bert"],
+        ["text-generation", "gpt2"],
+        ["text-generation-with-past", "gpt2"],
+        ["text2text-generation", "t5"],
+        ["text2text-generation-with-past", "t5"],
+        ["text-classification", "bert"],
         ["question-answering", "distilbert"],
-        ["masked-lm", "bert"],
-        ["default", "blenderbot"],
-        ["default-with-past", "blenderbot"],
         ["token-classification", "roberta"],
         ["image-classification", "vit"],
         ["audio-classification", "wav2vec2"],
+        ["fill-mask", "bert"],
+        ["feature-extraction", "blenderbot"],
         ["stable-diffusion", "stable-diffusion"],
+        ["stable-diffusion-xl", "stable-diffusion-xl"],
+        ["stable-diffusion-xl", "stable-diffusion-xl-refiner"],
@@ -51,9 +67,11 @@ def test_export(self, task: str, model_type: str):
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_exporters_cli(self, task: str, model_type: str):
-        with TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdir:
             subprocess.run(
-                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} {tmpdirname}",
+                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} {tmpdir}",
                 shell=True,
                 check=True,
             )
+            model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
+            eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)
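Each parameterized case shells out to the CLI and then checks that the export is loadable with the matching OVModel class. The equivalent manual invocation would look like this (model name and output directory illustrative):

    optimum-cli export openvino --model gpt2 --task text-generation-with-past ov_gpt2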