Add fp16 and int8 to OpenVINO models and export CLI #443

Merged: 16 commits, Oct 4, 2023
4 changes: 4 additions & 0 deletions optimum/commands/export/openvino.py
@@ -68,6 +68,8 @@ def parse_args_openvino(parser: "ArgumentParser"):
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"),
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"),


class OVExportCommand(BaseOptimumCLICommand):
@@ -102,5 +104,7 @@ def run(self):
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
fp16=self.args.fp16,
int8=self.args.int8,
# **input_shapes,
)
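With the two new flags parsed into `self.args` and forwarded to the export call, weight compression can be requested straight from the command line, for example `optimum-cli export openvino --model gpt2 --int8 gpt2_int8_ov/` (or `--fp16` for half-precision weights); the model id and output directory here are placeholders, not taken from this PR.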
43 changes: 24 additions & 19 deletions optimum/exporters/openvino/__main__.py
@@ -19,7 +19,6 @@

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers.utils import is_torch_available

from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
@@ -34,13 +33,10 @@

OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_DECODER_SIZE = 1e9
_MAX_UNCOMPRESSED_SIZE = 1e9

logger = logging.getLogger(__name__)

if is_torch_available():
import torch


def main_export(
model_name_or_path: str,
@@ -60,6 +56,7 @@ def main_export(
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
int8: Optional[bool] = None,
**kwargs_shapes,
):
"""
@@ -126,6 +123,13 @@ def main_export(
>>> main_export("gpt2", output="gpt2_onnx/")
```
"""
if int8 and not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

model_kwargs = model_kwargs or {}

output = Path(output)
if not output.exists():
output.mkdir(parents=True)
@@ -142,8 +146,6 @@
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
@@ -167,7 +169,6 @@
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
torch_dtype=torch_dtype,
device=device,
)

@@ -235,17 +236,19 @@
onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
model_kwargs = model_kwargs or {}
load_in_8bit = model_kwargs.get("load_in_8bit", None)
if load_in_8bit is None:
if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE:
if not is_nncf_available():
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
)
else:
model_kwargs["load_in_8bit"] = True

if int8 is None:
int8 = False
num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
int8 = True
logger.info("The model weights will be quantized to int8.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf, "
"please install it with `pip install nncf`."
)

if not is_stable_diffusion:
needs_pad_token_id = (
@@ -313,5 +316,7 @@
output_names=files_subpaths,
input_shapes=input_shapes,
device=device,
fp16=fp16,
int8=int8,
model_kwargs=model_kwargs,
)
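Taken together, `main_export` now accepts an explicit `int8` argument and, when it is left unset, compresses any model with at least 1e9 parameters automatically provided nncf is installed. A minimal sketch of the programmatic call, assuming this branch of optimum-intel; the model id and output paths are placeholders:

from optimum.exporters.openvino import main_export

# Explicit int8 weight compression (requires `pip install nncf`).
main_export("gpt2", output="gpt2_ov_int8/", int8=True)

# fp16 weight compression goes through the pre-existing fp16 argument.
main_export("gpt2", output="gpt2_ov_fp16/", fp16=True)

# With int8 left unset, models holding >= 1e9 parameters are compressed
# automatically when nncf is installed; smaller models stay uncompressed.
main_export("gpt2", output="gpt2_ov/")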
31 changes: 24 additions & 7 deletions optimum/exporters/openvino/convert.py
@@ -74,6 +74,8 @@ def export(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -115,6 +117,8 @@
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
int8=int8,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -133,7 +137,12 @@
)


def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
def export_tensorflow(
model: Union["PreTrainedModel", "ModelMixin"],
config: OnnxConfig,
opset: int,
output: Path,
):
"""
Export the TensorFlow model to OpenVINO format.

@@ -163,6 +172,8 @@ def export_pytorch_via_onnx(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
):
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -201,12 +212,11 @@
)
torch.onnx.export = orig_torch_onnx_export
ov_model = convert_model(str(onnx_output))
load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
_save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
load_in_8bit=load_in_8bit,
compress_to_fp16=fp16,
load_in_8bit=int8,
)
return input_names, output_names, True

@@ -219,6 +229,8 @@ def export_pytorch(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -313,7 +325,9 @@ def ts_patched_forward(*args, **kwargs):
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs)
return export_pytorch_via_onnx(
model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8
)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
@@ -334,8 +348,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
_save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=load_in_8bit)
_save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8)
clear_class_registry()
del model
gc.collect()
@@ -352,6 +365,8 @@ def export_models(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
@@ -396,6 +411,8 @@
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
int8=int8,
)
)

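Inside convert.py the new `fp16` and `int8` keywords replace the earlier `load_in_8bit` lookup in `model_kwargs` and are handed to `_save_model`, which is not part of this diff. A minimal sketch of what such a helper might look like, assuming `openvino.save_model` and `nncf.compress_weights`; both the helper body and these API choices are assumptions, not code from this PR:

from pathlib import Path

import openvino


def _save_model(model, path, compress_to_fp16=False, load_in_8bit=False):
    # Sketch only: optionally compress weights to int8 with nncf, then
    # serialize the IR, optionally storing FP32 constants as FP16.
    if load_in_8bit:
        import nncf  # optional dependency, validated earlier in main_export

        model = nncf.compress_weights(model)
    openvino.save_model(model, Path(path), compress_to_fp16=compress_to_fp16)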
21 changes: 17 additions & 4 deletions optimum/intel/openvino/modeling_base.py
@@ -29,7 +29,7 @@
from optimum.modeling_base import OptimizedModel

from ...exporters.openvino import export, main_export
from ..utils.import_utils import is_transformers_version
from ..utils.import_utils import is_nncf_available, is_transformers_version
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME


@@ -93,7 +93,7 @@ def __init__(
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

@staticmethod
def load_model(file_name: Union[str, Path]):
def load_model(file_name: Union[str, Path], load_in_8bit: bool = False):
"""
Loads the model.

@@ -120,6 +120,15 @@ def fix_op_names_duplicates(model: openvino.runtime.Model):
if file_name.suffix == ".onnx":
model = fix_op_names_duplicates(model) # should be called during model conversion to IR

if load_in_8bit:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)
import nncf

model = nncf.compress_weights(model)

return model

def _save_pretrained(self, save_directory: Union[str, Path]):
@@ -146,6 +155,7 @@ def _from_pretrained(
file_name: Optional[str] = None,
from_onnx: bool = False,
local_files_only: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
"""
@@ -203,7 +213,8 @@
model_save_dir = Path(model_cache_path).parent
file_name = file_names[0]

model = cls.load_model(file_name)
model = cls.load_model(file_name, load_in_8bit=load_in_8bit)

return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

@classmethod
@@ -219,6 +230,7 @@ def _from_transformers(
local_files_only: bool = False,
task: Optional[str] = None,
trust_remote_code: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
"""
@@ -253,10 +265,11 @@
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
int8=load_in_8bit,
)

config.save_pretrained(save_dir_path)
return cls._from_pretrained(model_id=save_dir_path, config=config, **kwargs)
return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs)

@classmethod
def _to_load(
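On the modeling side, `load_in_8bit` now travels from `from_pretrained` through `_from_transformers` (where it becomes `int8` for the export) and into `load_model`, which applies `nncf.compress_weights`. A usage sketch, assuming an OVModel subclass such as `OVModelForSequenceClassification` from optimum.intel; the model id is a placeholder:

from optimum.intel import OVModelForSequenceClassification

# Export a transformers checkpoint and compress its weights to int8 in one
# step; requires `pip install nncf`.
model = OVModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", export=True, load_in_8bit=True
)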
26 changes: 16 additions & 10 deletions optimum/intel/openvino/modeling_base_seq2seq.py
@@ -119,6 +119,7 @@ def _from_pretrained(
local_files_only: bool = False,
use_cache: bool = True,
from_onnx: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
"""
@@ -159,14 +160,14 @@
encoder_file_name = encoder_file_name or default_encoder_file_name
decoder_file_name = decoder_file_name or default_decoder_file_name
decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name

decoder_with_past = None
# Load model from a local directory
if os.path.isdir(model_id):
encoder = cls.load_model(os.path.join(model_id, encoder_file_name))
decoder = cls.load_model(os.path.join(model_id, decoder_file_name))
decoder_with_past = (
cls.load_model(os.path.join(model_id, decoder_with_past_file_name)) if use_cache else None
)
encoder = cls.load_model(os.path.join(model_id, encoder_file_name), load_in_8bit)
decoder = cls.load_model(os.path.join(model_id, decoder_file_name), load_in_8bit)
if use_cache:
decoder_with_past = cls.load_model(os.path.join(model_id, decoder_with_past_file_name), load_in_8bit)

model_save_dir = Path(model_id)

# Load model from hub
@@ -193,9 +194,10 @@
file_names[name] = model_cache_path

model_save_dir = Path(model_cache_path).parent
encoder = cls.load_model(file_names["encoder"])
decoder = cls.load_model(file_names["decoder"])
decoder_with_past = cls.load_model(file_names["decoder_with_past"]) if use_cache else None
encoder = cls.load_model(file_names["encoder"], load_in_8bit)
decoder = cls.load_model(file_names["decoder"], load_in_8bit)
if use_cache:
decoder_with_past = cls.load_model(file_names["decoder_with_past"], load_in_8bit)

return cls(
encoder=encoder,
@@ -220,6 +222,7 @@ def _from_transformers(
task: Optional[str] = None,
use_cache: bool = True,
trust_remote_code: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
"""
@@ -261,10 +264,13 @@
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
int8=load_in_8bit,
)

config.save_pretrained(save_dir_path)
return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)
return cls._from_pretrained(
model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs
)

def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
shapes = {}
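For seq2seq models the flag is applied to every submodel: encoder, decoder and, when `use_cache=True`, the decoder with past. A usage sketch, assuming `OVModelForSeq2SeqLM` from optimum.intel; the model id is a placeholder:

from optimum.intel import OVModelForSeq2SeqLM

# Encoder, decoder and decoder-with-past are each compressed to int8.
model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True, load_in_8bit=True)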
6 changes: 5 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -210,6 +210,7 @@ def _from_transformers(
task: Optional[str] = None,
use_cache: bool = True,
trust_remote_code: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
if config.model_type not in _SUPPORTED_ARCHITECTURES:
@@ -238,12 +239,15 @@
force_download=force_download,
trust_remote_code=trust_remote_code,
model_kwargs=kwargs,
int8=load_in_8bit,
)

config.is_decoder = True
config.is_encoder_decoder = False
config.save_pretrained(save_dir_path)
return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)
return cls._from_pretrained(
model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs
)

def _reshape(
self,
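Decoder-only models follow the same path: `load_in_8bit` is forwarded to `main_export` as `int8` and then passed back into `_from_pretrained`. A usage sketch, assuming `OVModelForCausalLM` from optimum.intel; the model id is a placeholder:

from optimum.intel import OVModelForCausalLM

# Export a causal LM to OpenVINO IR with int8-compressed weights.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=True)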