fix compatibility with transformers 4.36
eaidova committed Dec 14, 2023
1 parent 7323be5 commit 931cc64
Showing 1 changed file with 19 additions and 1 deletion.
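
Context for the change: transformers 4.36 made SDPA (scaled dot product attention) the default attention implementation for supported models, and the ONNX-based export path does not support SDPA for some architectures, so this commit loads affected model types with eager attention instead. A minimal sketch of the effect (the checkpoint name is an arbitrary example, not taken from the commit):

    # Load an affected architecture the way the patched exporter now does:
    # force the pre-4.36 eager attention implementation so export works.
    from transformers import AutoModelForSpeechSeq2Seq

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-tiny",  # whisper is one of the listed unsupported archs
        attn_implementation="eager",
    )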
optimum/exporters/openvino/__main__.py (19 additions, 1 deletion)
@@ -23,10 +23,21 @@
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import __main__ as optimum_main
 from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
+
+
+try:
+    from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
+except ImportError:
+    # Duplicated from https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py
+    # until it is part of a released optimum package
+    SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
+        "bart",
+        "whisper",
+    ]
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
-from ...intel.utils.import_utils import is_nncf_available
+from ...intel.utils.import_utils import is_nncf_available, is_transformers_version
 from .convert import export_models
 
 
@@ -140,10 +151,12 @@ def main_export(
     do_gptq_patching = False
     try:
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
+        model_type = config.model_type.replace("_", "-")
         config_dict = config.to_dict()
         quantization_config = config_dict.get("quantization_config", None)
         do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
     except Exception:
+        model_type = None
         pass
 
     if do_gptq_patching:
@@ -192,6 +205,10 @@ class StoreAttr(object):
                 f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
             )
 
+    loading_kwargs = {}
+    if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
+        loading_kwargs["attn_implementation"] = "eager"
+
     model = TasksManager.get_model_from_task(
         task,
         model_name_or_path,
@@ -204,6 +221,7 @@ class StoreAttr(object):
         trust_remote_code=trust_remote_code,
         framework=framework,
         device=device,
+        **loading_kwargs,
     )
 
     custom_architecture = False
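
A hedged usage sketch (the output and task arguments are assumed from the surrounding optimum API, not shown in the diff): after this fix, exporting one of the listed architectures under transformers >= 4.36 loads the model with eager attention automatically, so no extra flag is needed from the caller.

    # Assumed entry point: the main_export function patched above.
    from optimum.exporters.openvino import main_export

    main_export(
        model_name_or_path="openai/whisper-tiny",  # example checkpoint for a listed arch
        output="whisper_openvino",
        task="automatic-speech-recognition",
    )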