From 9d3207f4ad030fc6c43b6b58d777dbd8ce3ba803 Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 8 Dec 2023 15:01:23 +0400 Subject: [PATCH] Fix compatibility causallm models export with optimum 1.15 --- optimum/exporters/openvino/convert.py | 8 +++++--- optimum/intel/utils/import_utils.py | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 1c37473d24..70c8df30a9 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -31,7 +31,7 @@ from optimum.exporters.onnx.model_patcher import DecoderModelPatcher from optimum.utils import is_diffusers_available -from ...intel.utils.import_utils import is_nncf_available +from ...intel.utils.import_utils import is_nncf_available, is_optimum_version from .utils import ( OV_XML_FILE_NAME, clear_class_registry, @@ -307,8 +307,10 @@ def export_pytorch( # model.config.torchscript = True can not be used for patching, because it overrides return_dict to Flase if custom_patcher or dict_inputs: patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs) - # DecoderModelPatcher does not override model forward - if isinstance(patcher, DecoderModelPatcher) or patcher.orig_forward_name != "forward": + # DecoderModelPatcher does not override model forward in optimum < 1.15 + if ( + isinstance(patcher, DecoderModelPatcher) and is_optimum_version("<", "1.15.0") + ) or patcher.orig_forward_name != "forward": patch_model_forward = True patched_forward = model.forward else: diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index d15780384f..f778bbfcbd 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -29,6 +29,7 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} +_optimum_version = importlib_metadata.version("optimum") _transformers_available = importlib.util.find_spec("transformers") is not None _transformers_version = "N/A" @@ -175,6 +176,10 @@ def is_transformers_version(operation: str, version: str): return compare_versions(parse(_transformers_version), operation, version) +def is_optimum_version(operation: str, version: str): + return compare_versions(parse(_optimum_version), operation, version) + + def is_neural_compressor_version(operation: str, version: str): """ Compare the current Neural Compressor version to a given reference with an operation.