Skip to content

Commit

Permalink
rename var
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 3, 2023
1 parent 7505a19 commit 745da20
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,6 @@ def main_export(
input_shapes=input_shapes,
device=device,
fp16=fp16,
load_in_8bit=int8,
int8=int8,
model_kwargs=model_kwargs,
)
18 changes: 9 additions & 9 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def export(
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
load_in_8bit: bool = False,
int8: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
Expand Down Expand Up @@ -118,7 +118,7 @@ def export(
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
load_in_8bit=load_in_8bit,
int8=int8,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
Expand Down Expand Up @@ -173,7 +173,7 @@ def export_pytorch_via_onnx(
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
load_in_8bit: bool = False,
int8: bool = False,
):
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
Expand Down Expand Up @@ -216,7 +216,7 @@ def export_pytorch_via_onnx(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=fp16,
load_in_8bit=load_in_8bit,
load_in_8bit=int8,
)
return input_names, output_names, True

Expand All @@ -230,7 +230,7 @@ def export_pytorch(
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
load_in_8bit: bool = False,
int8: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
Expand Down Expand Up @@ -325,7 +325,7 @@ def ts_patched_forward(*args, **kwargs):
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, load_in_8bit=load_in_8bit)
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
Expand All @@ -346,7 +346,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
_save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=load_in_8bit)
_save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8)
clear_class_registry()
del model
gc.collect()
Expand All @@ -364,7 +364,7 @@ def export_models(
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
load_in_8bit: bool = False,
int8: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
Expand Down Expand Up @@ -410,7 +410,7 @@ def export_models(
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
load_in_8bit=load_in_8bit,
int8=int8,
)
)

Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def _from_transformers(
local_files_only: bool = False,
task: Optional[str] = None,
trust_remote_code: bool = False,
# load_in_8bit: bool = False, # TODO : add int8
**kwargs,
):
"""
Expand Down

0 comments on commit 745da20

Please sign in to comment.