Skip to content

Commit

Permalink
add auto model loading to test
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 28, 2023
1 parent 23b9627 commit cb04a62
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
2 changes: 2 additions & 0 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@


_HEAD_TO_AUTOMODELS = {
"feature-extraction" : "OVModelForFeatureExtraction",
"fill-mask": "OVModelForMaskedLM",
"text-generation": "OVModelForCausalLM",
"text2text-generation": "OVModelForSeq2SeqLM",
Expand All @@ -87,6 +88,7 @@
"image-classification": "OVModelForImageClassification",
"audio-classification": "OVModelForAudioClassification",
"stable-diffusion": "OVStableDiffusionPipeline",
"stable-diffusion-xl": "OVStableDiffusionXLPipeline",
}


Expand Down
35 changes: 26 additions & 9 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,38 @@
from utils_tests import MODEL_NAMES

from optimum.exporters.openvino.__main__ import main_export

from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS
from optimum.intel import (
OVModelForAudioClassification,
OVModelForCausalLM,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
OVModelForQuestionAnswering,
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
)

class OVCLIExportTestCase(unittest.TestCase):
"""
Integration tests ensuring supported models are correctly exported.
"""

SUPPORTED_ARCHITECTURES = (
["causal-lm", "gpt2"],
["causal-lm-with-past", "gpt2"],
["seq2seq-lm", "t5"],
["seq2seq-lm-with-past", "t5"],
["sequence-classification", "bert"],
["text-generation", "gpt2"],
["text-generation-with-past", "gpt2"],
["text2text-generation", "t5"],
["text2text-generation-with-past", "t5"],
["text-classification", "bert"],
["question-answering", "distilbert"],
["masked-lm", "bert"],
["default", "blenderbot"],
["default-with-past", "blenderbot"],
["token-classification", "roberta"],
["image-classification", "vit"],
["audio-classification", "wav2vec2"],
["fill-mask", "bert"],
["feature-extraction", "blenderbot"],
["stable-diffusion", "stable-diffusion"],
["stable-diffusion-xl", "stable-diffusion-xl"],
["stable-diffusion-xl", "stable-diffusion-xl-refiner"],
Expand All @@ -57,3 +72,5 @@ def test_exporters_cli(self, task: str, model_type: str):
shell=True,
check=True,
)
model_kwargs = {"use_cache" : task.endswith("with-past")} if "generation" in task else {}
ov_model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdirname, **model_kwargs)

0 comments on commit cb04a62

Please sign in to comment.