Commit 7505a19: merge main in branch

echarlaix committed Oct 3, 2023
2 parents 3ddbafe + d207110
Showing 3 changed files with 37 additions and 19 deletions.
23 changes: 10 additions & 13 deletions optimum/exporters/openvino/__main__.py
@@ -37,6 +37,8 @@

 OV_XML_FILE_NAME = "openvino_model.xml"

+_MAX_UNCOMPRESSED_SIZE = 1e9
+
 logger = logging.getLogger(__name__)


@@ -58,7 +60,7 @@ def main_export(
     model_kwargs: Optional[Dict[str, Any]] = None,
     custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
     fn_get_submodels: Optional[Callable] = None,
-    int8: Optional[bool] = False,
+    int8: Optional[bool] = None,
     **kwargs_shapes,
 ):
     """
@@ -138,6 +140,9 @@

         import nncf

+    if model_kwargs is None:
+        model_kwargs = {}
+
     output = Path(output)
     if not output.exists():
         output.mkdir(parents=True)
@@ -245,6 +250,9 @@ def main_export(
     onnx_config = onnx_config_constructor(model.config)
     models_and_onnx_configs = {"model": (model, onnx_config)}

+    if int8 is None:
+        int8 = (model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()) >= _MAX_UNCOMPRESSED_SIZE
+
     if not is_stable_diffusion:
         needs_pad_token_id = (
             isinstance(onnx_config, OnnxConfigWithPast)
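With int8 now defaulting to None, 8-bit weight compression switches on automatically once the model (or, for stable diffusion, its unet) reaches _MAX_UNCOMPRESSED_SIZE = 1e9 parameters. A minimal sketch of that decision with hypothetical parameter counts (resolve_int8 is an illustrative helper, not part of the codebase):

    _MAX_UNCOMPRESSED_SIZE = 1e9

    def resolve_int8(int8, num_parameters):
        # None means "decide from model size"; an explicit True/False always wins.
        if int8 is None:
            int8 = num_parameters >= _MAX_UNCOMPRESSED_SIZE
        return int8

    assert resolve_int8(None, 124_000_000) is False     # ~gpt2-sized: left uncompressed
    assert resolve_int8(None, 6_700_000_000) is True    # ~6.7B params: compressed to int8
    assert resolve_int8(False, 6_700_000_000) is False  # user override is respected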
@@ -312,17 +320,6 @@ def main_export(
         input_shapes=input_shapes,
         device=device,
         fp16=fp16,
+        load_in_8bit=int8,
         model_kwargs=model_kwargs,
     )
-    del models_and_onnx_configs
-
-    if int8:
-        for model_path in files_subpaths:
-            model = core.read_model(output / model_path)
-            model = nncf.compress_weights(model)
-
-            for filename in (model_path, model_path.replace("xml", "bin")):
-                os.remove(output / filename)
-
-            openvino.save_model(model, output / model_path, compress_to_fp16=False)
-            del model
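The upshot for callers of main_export: compression is no longer a post-processing loop over the saved IR files but a flag threaded into the export itself. A hypothetical invocation (a sketch: the model id and task are placeholders, and the keyword names assume the signature shown above):

    from optimum.exporters.openvino import main_export

    main_export(
        model_name_or_path="gpt2",         # placeholder model id
        output="ov_model",                 # target directory for the IR
        task="text-generation-with-past",  # placeholder task
        int8=None,                         # new default: decided by model size
    )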
32 changes: 26 additions & 6 deletions optimum/exporters/openvino/convert.py
@@ -30,6 +30,7 @@
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
 from optimum.utils import is_diffusers_available

+from ...intel.utils.import_utils import is_nncf_available
 from .utils import (
     OV_XML_FILE_NAME,
     clear_class_registry,
@@ -52,6 +53,19 @@
     from transformers.modeling_tf_utils import TFPreTrainedModel


+def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False):
+    if load_in_8bit:
+        if not is_nncf_available():
+            logger.warning(
+                "The model will be converted without weight quantization. Quantization of the weights to int8 requires nncf, "
+                "please install it with `pip install nncf`."
+            )
+        else:
+            import nncf
+
+            model = nncf.compress_weights(model)
+    save_model(model, path, compress_to_fp16)
+
+
 def export(
     model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
     config: OnnxConfig,
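_save_model becomes the single choke point for serialization: every export path below funnels through it, so int8 compression now runs on the in-memory model before the IR is first written, replacing the read-compress-rewrite loop deleted from __main__.py above. A minimal usage sketch, reusing the convert_model already imported in this module (the ONNX filename is a placeholder):

    ov_model = convert_model("model.onnx")  # placeholder ONNX file
    # nncf.compress_weights runs in memory; only the compressed IR hits disk:
    _save_model(ov_model, "openvino_model.xml", compress_to_fp16=False, load_in_8bit=True)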
@@ -61,6 +75,7 @@ def export(
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
+    load_in_8bit: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a PyTorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -103,6 +118,7 @@
             input_shapes=input_shapes,
             model_kwargs=model_kwargs,
             fp16=fp16,
+            load_in_8bit=load_in_8bit,
         )

     elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -113,7 +129,7 @@ def export(
             raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
         if input_shapes is not None:
             logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
-        return export_tensorflow(model, config, opset, output, fp16=fp16)
+        return export_tensorflow(model, config, opset, output)

     else:
         raise RuntimeError(
@@ -126,7 +142,6 @@ def export_tensorflow(
     config: OnnxConfig,
     opset: int,
     output: Path,
-    fp16: bool = False,
 ):
     """
     Export the TensorFlow model to OpenVINO format.
@@ -145,7 +160,7 @@ def export_tensorflow(
     onnx_path = Path(output).with_suffix(".onnx")
     input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
     ov_model = convert_model(str(onnx_path))
-    save_model(ov_model, output.parent / output, compress_to_fp16=fp16)
+    _save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False)
     return input_names, output_names, True
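Note that the TensorFlow path loses its fp16 argument and hardcodes compress_to_fp16=False, load_in_8bit=False, so the new 8-bit compression effectively applies only to the PyTorch export paths.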


@@ -158,6 +173,7 @@ def export_pytorch_via_onnx(
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
+    load_in_8bit: bool = False,
 ):
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -196,10 +212,11 @@ def export_pytorch_via_onnx(
     )
     torch.onnx.export = orig_torch_onnx_export
     ov_model = convert_model(str(onnx_output))
-    save_model(
+    _save_model(
         ov_model,
         output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
         compress_to_fp16=fp16,
+        load_in_8bit=load_in_8bit,
     )
     return input_names, output_names, True

@@ -213,6 +230,7 @@ def export_pytorch(
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
+    load_in_8bit: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -307,7 +325,7 @@ def ts_patched_forward(*args, **kwargs):
             ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
         except Exception as ex:
             logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
-            return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16)
+            return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, load_in_8bit=load_in_8bit)
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         ordered_input_names = list(inputs)
         flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
@@ -328,7 +346,7 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_node().set_partial_shape(static_shape)
             inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
-        save_model(ov_model, output, compress_to_fp16=fp16)
+        _save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=load_in_8bit)
         clear_class_registry()
         del model
         gc.collect()
@@ -346,6 +364,7 @@ def export_models(
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     fp16: bool = False,
+    load_in_8bit: bool = False,
 ) -> Tuple[List[List[str]], List[List[str]]]:
     """
     Export the models to OpenVINO IR format
@@ -391,6 +410,7 @@ def export_models(
                 input_shapes=input_shapes,
                 model_kwargs=model_kwargs,
                 fp16=fp16,
+                load_in_8bit=load_in_8bit,
             )
         )

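Taken together, the convert.py changes thread load_in_8bit through every entry point down to the new _save_model; roughly, as read from this diff:

    main_export(..., int8=None)                  # __main__.py: resolves int8 from model size
      └─ export_models(..., load_in_8bit=int8)
           └─ export(..., load_in_8bit=...)
                ├─ export_pytorch(...)           # direct conversion → _save_model(...)
                ├─ export_pytorch_via_onnx(...)  # ONNX fallback     → _save_model(...)
                └─ export_tensorflow(...)        # always → _save_model(..., load_in_8bit=False)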
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_decoder.py
@@ -237,6 +237,7 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
+            model_kwargs=kwargs,
         )

         config.is_decoder = True
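This one-line change forwards the leftover kwargs of _from_transformers to the export call as model_kwargs, so model-specific loading options are no longer dropped during conversion. A hypothetical sketch of the public entry point that exercises it (assuming the optimum-intel API of this release):

    from optimum.intel import OVModelForCausalLM

    # export=True converts the checkpoint through main_export; with this change,
    # any extra keyword arguments reach it as model_kwargs.
    model = OVModelForCausalLM.from_pretrained("gpt2", export=True)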