Add fp16 and int8 to OpenVINO models and export CLI #443

Merged: 16 commits, Oct 4, 2023
4 changes: 4 additions & 0 deletions optimum/commands/export/openvino.py
@@ -68,6 +68,8 @@ def parse_args_openvino(parser: "ArgumentParser"):
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"),
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"),


class OVExportCommand(BaseOptimumCLICommand):
@@ -102,5 +104,7 @@ def run(self):
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
fp16=self.args.fp16,
int8=self.args.int8,
# **input_shapes,
)
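The new flags surface in the exporter CLI; a minimal sketch of driving the updated command from Python, mirroring how the test suite below invokes it (the model name and output directory are placeholders):

```python
import subprocess

# Export with int8 weight compression; pass --fp16 instead for FP16 weights
# (the two options are mutually exclusive).
subprocess.run(
    "optimum-cli export openvino --model gpt2 --int8 gpt2_ov/",
    shell=True,
    check=True,
)
```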
39 changes: 32 additions & 7 deletions optimum/exporters/openvino/__main__.py
@@ -19,25 +19,26 @@

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers.utils import is_torch_available

import openvino
from openvino import Core
from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models


core = Core()

OV_XML_FILE_NAME = "openvino_model.xml"

logger = logging.getLogger(__name__)

if is_torch_available():
import torch


def main_export(
model_name_or_path: str,
@@ -57,6 +58,7 @@ def main_export(
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
int8: Optional[bool] = False,
**kwargs_shapes,
):
"""
@@ -123,6 +125,19 @@
>>> main_export("gpt2", output="gpt2_onnx/")
```
"""
if int8:
if fp16:
raise ValueError(
"Both `fp16` and `int8` were both set to `True`, please select only one of these options."
)

if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

import nncf

output = Path(output)
if not output.exists():
output.mkdir(parents=True)
@@ -139,8 +154,6 @@
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
@@ -164,7 +177,6 @@
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
torch_dtype=torch_dtype,
device=device,
)

@@ -299,5 +311,18 @@ def main_export(
output_names=files_subpaths,
input_shapes=input_shapes,
device=device,
fp16=fp16,
model_kwargs=model_kwargs,
)
del models_and_onnx_configs

if int8:
Collaborator:
One comment from my side. Such an implementation means that we will:

  • Convert from PyTorch to OpenVINO with memory reuse (no-copy).
  • However, after that, we serialize to disk, load the model again, quantize the weights, remove the files from disk, and finally store the compressed version.
  • This can be time-consuming for really large models.

Ideally, we should have nncf.compress_weights() right after openvino.convert_model(), similar to FP16.
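A minimal sketch of the suggested in-memory flow, assuming the `openvino.convert_model`, `nncf.compress_weights`, and `openvino.save_model` APIs available at the time; the toy model, example input, and output path are placeholders:

```python
from pathlib import Path

import nncf
import openvino
import torch

# Placeholder model and example input standing in for the real exported pipeline.
pt_model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example_input = torch.randn(1, 16)
output = Path("ov_out")
output.mkdir(parents=True, exist_ok=True)

# Convert PyTorch -> OpenVINO entirely in memory (no intermediate files).
ov_model = openvino.convert_model(pt_model, example_input=example_input)

# Compress weights to int8 directly on the in-memory model.
ov_model = nncf.compress_weights(ov_model)

# Serialize once; FP16 compression is skipped since the weights are already int8.
openvino.save_model(ov_model, output / "openvino_model.xml", compress_to_fp16=False)
```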

Collaborator:
I opened a follow-up PR #444 for the default 8-bit compression. I tried to follow the approach described above.

Collaborator:
@eaidova, please take a look as well.

Collaborator (Author) @echarlaix, Oct 2, 2023:
> Ideally, we should have nncf.compress_weights() right after openvino.convert_model(), similar to FP16.

Yes, that's a good point. I was going to modify this PR, but I see that you opened #444. Thanks @AlexKoff88!

for model_path in files_subpaths:
model = core.read_model(output / model_path)
model = nncf.compress_weights(model)

for filename in (model_path, model_path.replace("xml", "bin")):
os.remove(output / filename)

openvino.save_model(model, output / model_path, compress_to_fp16=False)
del model
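After this change, the Python entry point can be driven directly with the new arguments; a minimal usage sketch mirroring the tests (model name, task, and output directories are placeholders, and `int8=True` requires `pip install nncf`):

```python
from optimum.exporters.openvino.__main__ import main_export

# Export with int8 weight compression (fp16 and int8 are mutually exclusive).
main_export("gpt2", output="gpt2_ov_int8/", task="text-generation", int8=True)

# Or export with FP16 weight compression instead.
main_export("gpt2", output="gpt2_ov_fp16/", task="text-generation", fp16=True)
```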
24 changes: 18 additions & 6 deletions optimum/exporters/openvino/convert.py
@@ -60,6 +60,7 @@ def export(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -101,6 +102,7 @@
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -111,15 +113,21 @@
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if input_shapes is not None:
logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
return export_tensorflow(model, config, opset, output)
return export_tensorflow(model, config, opset, output, fp16=fp16)

else:
raise RuntimeError(
"You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
)


def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
def export_tensorflow(
model: Union["PreTrainedModel", "ModelMixin"],
config: OnnxConfig,
opset: int,
output: Path,
fp16: bool = False,
):
"""
Export the TensorFlow model to OpenVINO format.

@@ -137,7 +145,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
onnx_path = Path(output).with_suffix(".onnx")
input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
ov_model = convert_model(str(onnx_path))
save_model(ov_model, output.parent / output, compress_to_fp16=False)
save_model(ov_model, output.parent / output, compress_to_fp16=fp16)
return input_names, output_names, True


@@ -149,6 +157,7 @@ def export_pytorch_via_onnx(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
):
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -190,7 +199,7 @@
save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
compress_to_fp16=fp16,
)
return input_names, output_names, True

@@ -203,6 +212,7 @@ def export_pytorch(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -297,7 +307,7 @@ def ts_patched_forward(*args, **kwargs):
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs)
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
@@ -318,7 +328,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
save_model(ov_model, output, compress_to_fp16=False)
save_model(ov_model, output, compress_to_fp16=fp16)
clear_class_registry()
del model
gc.collect()
@@ -335,6 +345,7 @@ def export_models(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
@@ -379,6 +390,7 @@
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
)
)

87 changes: 70 additions & 17 deletions tests/openvino/test_exporters_cli.py
@@ -16,7 +16,7 @@
from tempfile import TemporaryDirectory

from parameterized import parameterized
from utils_tests import MODEL_NAMES
from utils_tests import MODEL_NAMES, get_num_quantized_nodes

from optimum.exporters.openvino.__main__ import main_export
from optimum.intel import ( # noqa
@@ -35,31 +35,47 @@
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS


_ARCHITECTURES_TO_EXPECTED_INT8 = {
"bert": (34,),
"roberta": (34,),
"albert": (42,),
"vit": (31,),
"blenderbot": (35,),
"gpt2": (22,),
"wav2vec2": (15,),
"distilbert": (33,),
"t5": (32, 52, 42),
"stable-diffusion": (74, 4, 4, 32),
"stable-diffusion-xl": (148, 4, 4, 33),
"stable-diffusion-xl-refiner": (148, 4, 4, 33),
}


class OVCLIExportTestCase(unittest.TestCase):
"""
Integration tests ensuring supported models are correctly exported.
"""

SUPPORTED_ARCHITECTURES = (
["text-generation", "gpt2"],
["text-generation-with-past", "gpt2"],
["text2text-generation", "t5"],
["text2text-generation-with-past", "t5"],
["text-classification", "bert"],
["question-answering", "distilbert"],
["token-classification", "roberta"],
["image-classification", "vit"],
["audio-classification", "wav2vec2"],
["fill-mask", "bert"],
["feature-extraction", "blenderbot"],
["stable-diffusion", "stable-diffusion"],
["stable-diffusion-xl", "stable-diffusion-xl"],
["stable-diffusion-xl", "stable-diffusion-xl-refiner"],
("text-generation", "gpt2"),
("text-generation-with-past", "gpt2"),
("text2text-generation", "t5"),
("text2text-generation-with-past", "t5"),
("text-classification", "albert"),
("question-answering", "distilbert"),
("token-classification", "roberta"),
("image-classification", "vit"),
("audio-classification", "wav2vec2"),
("fill-mask", "bert"),
("feature-extraction", "blenderbot"),
("stable-diffusion", "stable-diffusion"),
("stable-diffusion-xl", "stable-diffusion-xl"),
("stable-diffusion-xl", "stable-diffusion-xl-refiner"),
)

def _openvino_export(self, model_name: str, task: str):
def _openvino_export(self, model_name: str, task: str, fp16: bool = False, int8: bool = False):
with TemporaryDirectory() as tmpdir:
main_export(model_name_or_path=model_name, output=tmpdir, task=task)
main_export(model_name_or_path=model_name, output=tmpdir, task=task, fp16=fp16, int8=int8)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_export(self, task: str, model_type: str):
@@ -75,3 +91,40 @@ def test_exporters_cli(self, task: str, model_type: str):
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli_fp16(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --fp16 {tmpdir}",
shell=True,
check=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli_int8(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --int8 {tmpdir}",
shell=True,
check=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

if task.startswith("text2text-generation"):
models = [model.encoder, model.decoder]
if task.endswith("with-past"):
models.append(model.decoder_with_past)
elif task.startswith("stable-diffusion"):
models = [model.unet, model.vae_encoder, model.vae_decoder]
models.append(model.text_encoder if task == "stable-diffusion" else model.text_encoder_2)
else:
models = [model]

expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type]
for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8[i], num_int8)
14 changes: 1 addition & 13 deletions tests/openvino/test_quantization.py
@@ -42,25 +42,14 @@
OVTrainer,
)
from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
from utils_tests import get_num_quantized_nodes

_TASK_TO_DATASET = {
"text-generation": ("wikitext", "wikitext-2-raw-v1", "text"),
"text-classification": ("glue", "sst2", "sentence"),
}


def get_num_quantized_nodes(ov_model):
num_fake_quantize = 0
num_int8 = 0
for elem in ov_model.model.get_ops():
if "FakeQuantize" in elem.name:
num_fake_quantize += 1
for i in range(elem.get_output_size()):
if "8" in elem.get_output_element_type(i).get_type_name():
num_int8 += 1
return num_fake_quantize, num_int8


class OVQuantizerTest(unittest.TestCase):
# TODO : add models
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
@@ -164,7 +153,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
quantizer.quantize(save_directory=tmp_dir, weights_only=True)
model = model_cls.from_pretrained(tmp_dir)

# TODO: uncomment once move to a newer version of NNCF which has some fixes
_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_pt_int8, num_int8)

13 changes: 13 additions & 0 deletions tests/openvino/utils_tests.py
@@ -92,3 +92,16 @@
}

SEED = 42


def get_num_quantized_nodes(ov_model):
num_fake_quantize = 0
num_int8 = 0
for elem in ov_model.model.get_ops():
if "FakeQuantize" in elem.name:
num_fake_quantize += 1
for i in range(elem.get_output_size()):
if "8" in elem.get_output_element_type(i).get_type_name():
num_int8 += 1
return num_fake_quantize, num_int8
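
A usage sketch of this helper, assuming a hypothetical directory produced with `--int8` and the `OVModelForSequenceClassification` wrapper from `optimum.intel`:

```python
from optimum.intel import OVModelForSequenceClassification

# "my_int8_export/" is a placeholder for a model exported with the --int8 flag.
ov_model = OVModelForSequenceClassification.from_pretrained("my_int8_export/")
num_fake_quantize, num_int8 = get_num_quantized_nodes(ov_model)
print(f"FakeQuantize nodes: {num_fake_quantize}, int8 ops: {num_int8}")
```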