add tests
echarlaix committed Sep 29, 2023
1 parent 49f0eb7 commit d44e13c
Showing 3 changed files with 84 additions and 30 deletions.
87 changes: 70 additions & 17 deletions tests/openvino/test_exporters_cli.py
@@ -16,7 +16,7 @@
from tempfile import TemporaryDirectory

from parameterized import parameterized
-from utils_tests import MODEL_NAMES
+from utils_tests import MODEL_NAMES, get_num_quantized_nodes

from optimum.exporters.openvino.__main__ import main_export
from optimum.intel import ( # noqa
@@ -35,31 +35,47 @@
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS


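# Expected number of int8 nodes after 8-bit export, one count per exported submodel
# (e.g. encoder / decoder / decoder-with-past for t5; unet / vae-encoder /
# vae-decoder / text-encoder for the diffusion pipelines).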
_ARCHITECTURES_TO_EXPECTED_INT8 = {
"bert": (34,),
"roberta": (34,),
"albert": (42,),
"vit": (31,),
"blenderbot": (35,),
"gpt2": (22,),
"wav2vec2": (15,),
"distilbert": (33,),
"t5": (32, 52, 42),
"stable-diffusion": (74, 4, 4, 32),
"stable-diffusion-xl": (148, 4, 4, 33),
"stable-diffusion-xl-refiner": (148, 4, 4, 33),
}


class OVCLIExportTestCase(unittest.TestCase):
"""
Integration tests ensuring supported models are correctly exported.
"""

SUPPORTED_ARCHITECTURES = (
["text-generation", "gpt2"],
["text-generation-with-past", "gpt2"],
["text2text-generation", "t5"],
["text2text-generation-with-past", "t5"],
["text-classification", "bert"],
["question-answering", "distilbert"],
["token-classification", "roberta"],
["image-classification", "vit"],
["audio-classification", "wav2vec2"],
["fill-mask", "bert"],
["feature-extraction", "blenderbot"],
["stable-diffusion", "stable-diffusion"],
["stable-diffusion-xl", "stable-diffusion-xl"],
["stable-diffusion-xl", "stable-diffusion-xl-refiner"],
("text-generation", "gpt2"),
("text-generation-with-past", "gpt2"),
("text2text-generation", "t5"),
("text2text-generation-with-past", "t5"),
("text-classification", "albert"),
("question-answering", "distilbert"),
("token-classification", "roberta"),
("image-classification", "vit"),
("audio-classification", "wav2vec2"),
("fill-mask", "bert"),
("feature-extraction", "blenderbot"),
("stable-diffusion", "stable-diffusion"),
("stable-diffusion-xl", "stable-diffusion-xl"),
("stable-diffusion-xl", "stable-diffusion-xl-refiner"),
)

-    def _openvino_export(self, model_name: str, task: str):
+    def _openvino_export(self, model_name: str, task: str, fp16: bool = False, int8: bool = False):
        with TemporaryDirectory() as tmpdir:
-            main_export(model_name_or_path=model_name, output=tmpdir, task=task)
+            main_export(model_name_or_path=model_name, output=tmpdir, task=task, fp16=fp16, int8=int8)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_export(self, task: str, model_type: str):
Expand All @@ -75,3 +91,40 @@ def test_exporters_cli(self, task: str, model_type: str):
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli_fp16(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --fp16 {tmpdir}",
shell=True,
check=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli_int8(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --int8 {tmpdir}",
shell=True,
check=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

if task.startswith("text2text-generation"):
models = [model.encoder, model.decoder]
if task.endswith("with-past"):
                    models.append(model.decoder_with_past)
elif task.startswith("stable-diffusion"):
models = [model.unet, model.vae_encoder, model.vae_decoder]
models.append(model.text_encoder if task == "stable-diffusion" else model.text_encoder_2)
else:
models = [model]

expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type]
for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8[i], num_int8)
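
For reference, the int8 path these new tests exercise boils down to the following flow (a minimal sketch, assuming it is run from tests/openvino so that utils_tests is importable and that optimum-cli is installed; it is not part of the commit):

import subprocess
from tempfile import TemporaryDirectory

from optimum.intel import OVModelForCausalLM
from utils_tests import MODEL_NAMES, get_num_quantized_nodes

with TemporaryDirectory() as tmpdir:
    # Export a tiny gpt2 checkpoint with 8-bit weight compression via the CLI.
    subprocess.run(
        f"optimum-cli export openvino --model {MODEL_NAMES['gpt2']} --task text-generation-with-past --int8 {tmpdir}",
        shell=True,
        check=True,
    )
    # Reload the exported model and count int8 nodes, as test_exporters_cli_int8 does.
    model = OVModelForCausalLM.from_pretrained(tmpdir, use_cache=True)
    _, num_int8 = get_num_quantized_nodes(model)
    assert num_int8 == 22  # per _ARCHITECTURES_TO_EXPECTED_INT8["gpt2"] above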
14 changes: 1 addition & 13 deletions tests/openvino/test_quantization.py
@@ -42,25 +42,14 @@
OVTrainer,
)
from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
+from utils_tests import get_num_quantized_nodes

_TASK_TO_DATASET = {
"text-generation": ("wikitext", "wikitext-2-raw-v1", "text"),
"text-classification": ("glue", "sst2", "sentence"),
}


-def get_num_quantized_nodes(ov_model):
-    num_fake_quantize = 0
-    num_int8 = 0
-    for elem in ov_model.model.get_ops():
-        if "FakeQuantize" in elem.name:
-            num_fake_quantize += 1
-        for i in range(elem.get_output_size()):
-            if "8" in elem.get_output_element_type(i).get_type_name():
-                num_int8 += 1
-    return num_fake_quantize, num_int8


class OVQuantizerTest(unittest.TestCase):
# TODO : add models
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
@@ -164,7 +153,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
quantizer.quantize(save_directory=tmp_dir, weights_only=True)
model = model_cls.from_pretrained(tmp_dir)

-        # TODO: uncomment once move to a newer version of NNCF which has some fixes
_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_pt_int8, num_int8)

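For context, the weight-only compression path that test_automodel_weight_compression exercises reduces to roughly the following (a minimal sketch, not the test itself; the tiny checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

from optimum.intel import OVModelForCausalLM, OVQuantizer

pt_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
quantizer = OVQuantizer.from_pretrained(pt_model)
# weights_only=True compresses weights to 8 bit without quantizing activations,
# so no calibration dataset is required.
quantizer.quantize(save_directory="ov_int8", weights_only=True)
ov_model = OVModelForCausalLM.from_pretrained("ov_int8")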
13 changes: 13 additions & 0 deletions tests/openvino/utils_tests.py
@@ -92,3 +92,16 @@
}

SEED = 42


def get_num_quantized_nodes(ov_model):
    """Return the number of FakeQuantize ops and of int8-typed op outputs in an OpenVINO model."""
    num_fake_quantize = 0
    num_int8 = 0
    for elem in ov_model.model.get_ops():
        if "FakeQuantize" in elem.name:
            num_fake_quantize += 1
        for i in range(elem.get_output_size()):
            if "8" in elem.get_output_element_type(i).get_type_name():
                num_int8 += 1
    return num_fake_quantize, num_int8
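
# A hypothetical sanity check of the convention the helper relies on: OpenVINO
# element types stringify via get_type_name() (e.g. "i8", "u8", "f32"), so the
# substring test above counts outputs of both signed and unsigned 8-bit types.
from openvino.runtime import Type

assert "8" in Type.i8.get_type_name()
assert "8" in Type.u8.get_type_name()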
