Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 3, 2023
1 parent 745da20 commit 2644e0b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer

import openvino
from openvino import Core
from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
Expand Down Expand Up @@ -138,7 +137,6 @@ def main_export(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

import nncf

if model_kwargs is None:
model_kwargs = {}
Expand Down Expand Up @@ -251,7 +249,9 @@ def main_export(
models_and_onnx_configs = {"model": (model, onnx_config)}

if int8 is None:
int8 = (model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()) >= _MAX_UNCOMPRESSED_SIZE
int8 = (
model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
) >= _MAX_UNCOMPRESSED_SIZE

if not is_stable_diffusion:
needs_pad_token_id = (
Expand Down
4 changes: 3 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def ts_patched_forward(*args, **kwargs):
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8)
return export_pytorch_via_onnx(
model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8
)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
Expand Down

0 comments on commit 2644e0b

Please sign in to comment.