Skip to content

Commit

Permalink
Add Whisper support for OpenVINO (#470)
Browse files Browse the repository at this point in the history
* add Whisper support for OpenVINO

* add test

* fix tests

* restrict transformers version for now...

* allow to run on GPU

* apply review comments

* fix compatibility with transformers 4.36

* fix generate

* apply comments

* fix pix2struct
  • Loading branch information
eaidova authored Dec 14, 2023
1 parent a30e337 commit b056177
Show file tree
Hide file tree
Showing 9 changed files with 705 additions and 45 deletions.
18 changes: 17 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,19 @@
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
from .convert import export_models


# Architectures whose ONNX/OpenVINO export requires eager (non-SDPA) attention.
# Optimum >= 1.16.0 exposes this list directly; for older Optimum releases fall
# back to a vendored copy so the export path still works.
if is_optimum_version(">=", "1.16.0"):
    from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
else:
    # Copied from https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py
    SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
        "bart",
        "whisper",
    ]

OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_SIZE = 1e9
Expand Down Expand Up @@ -140,10 +149,12 @@ def main_export(
# Best-effort inspection of the checkpoint config: detect GPTQ quantization
# (which needs export-time patching) and record the model type for the
# SDPA-compatibility check below.
do_gptq_patching = False
try:
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
    # Exporter tables use dash-separated model types (e.g. "gpt_neox" -> "gpt-neox").
    model_type = config.model_type.replace("_", "-")
    config_dict = config.to_dict()
    quantization_config = config_dict.get("quantization_config", None)
    do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
except Exception:
    # Config may be missing or unreadable; proceed without GPTQ patching and
    # with an unknown model type rather than failing the whole export here.
    model_type = None
    pass

if do_gptq_patching:
Expand Down Expand Up @@ -192,6 +203,10 @@ class StoreAttr(object):
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

loading_kwargs = {}
if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

model = TasksManager.get_model_from_task(
task,
model_name_or_path,
Expand All @@ -204,6 +219,7 @@ class StoreAttr(object):
trust_remote_code=trust_remote_code,
framework=framework,
device=device,
**loading_kwargs,
)

custom_architecture = False
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
"OVModelForPix2Struct",
"OVModelForQuestionAnswering",
"OVModelForSeq2SeqLM",
"OVModelForSpeechSeq2Seq",
"OVModelForSequenceClassification",
"OVModelForTokenClassification",
]
Expand Down Expand Up @@ -195,6 +196,7 @@
OVModelForQuestionAnswering,
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
)

Expand Down
2 changes: 1 addition & 1 deletion optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
OVModelForTokenClassification,
)
from .modeling_decoder import OVModelForCausalLM
from .modeling_seq2seq import OVModelForPix2Struct, OVModelForSeq2SeqLM
from .modeling_seq2seq import OVModelForPix2Struct, OVModelForSeq2SeqLM, OVModelForSpeechSeq2Seq


if is_diffusers_available():
Expand Down
2 changes: 0 additions & 2 deletions optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ def __init__(
self.ov_config = ov_config if ov_config is not None else {}
self.preprocessors = kwargs.get("preprocessors", [])

if "GPU" in self._device:
raise ValueError("Support of dynamic shapes for GPU devices is not yet available.")
if self.is_dynamic:
encoder = self._reshape(encoder, -1, -1, is_decoder=False)
decoder = self._reshape(decoder, -1, -1)
Expand Down
Loading

0 comments on commit b056177

Please sign in to comment.