Skip to content

Commit

Permalink
Fixed issue
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKoff88 committed Oct 2, 2023
1 parent 821d2a9 commit aa0c6ad
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,12 @@ def export_pytorch_via_onnx(
)
torch.onnx.export = orig_torch_onnx_export
ov_model = convert_model(str(onnx_output))
load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
_save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
load_in_8bit=model_kwargs.get("load_in_8bit", False),
load_in_8bit=load_in_8bit,
)
return input_names, output_names, True

Expand Down Expand Up @@ -326,7 +327,8 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
_save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=model_kwargs.get("load_in_8bit", False))
load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
_save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=load_in_8bit)
clear_class_registry()
del model
gc.collect()
Expand Down

0 comments on commit aa0c6ad

Please sign in to comment.