Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 28, 2023
1 parent cb04a62 commit 95b223c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@


_HEAD_TO_AUTOMODELS = {
"feature-extraction" : "OVModelForFeatureExtraction",
"feature-extraction": "OVModelForFeatureExtraction",
"fill-mask": "OVModelForMaskedLM",
"text-generation": "OVModelForCausalLM",
"text2text-generation": "OVModelForSeq2SeqLM",
Expand Down
11 changes: 6 additions & 5 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from utils_tests import MODEL_NAMES

from optimum.exporters.openvino.__main__ import main_export
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS
from optimum.intel import (
from optimum.intel import ( # noqa
OVModelForAudioClassification,
OVModelForCausalLM,
OVModelForFeatureExtraction,
Expand All @@ -33,6 +32,8 @@
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS


class OVCLIExportTestCase(unittest.TestCase):
"""
Expand Down Expand Up @@ -66,11 +67,11 @@ def test_export(self, task: str, model_type: str):

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdirname:
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} {tmpdirname}",
shell=True,
check=True,
)
model_kwargs = {"use_cache" : task.endswith("with-past")} if "generation" in task else {}
ov_model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdirname, **model_kwargs)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

0 comments on commit 95b223c

Please sign in to comment.