Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKoff88 committed Oct 2, 2023
1 parent 8801566 commit 821d2a9
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ def main_export(
if load_in_8bit is None:
if model_kwargs is None:
model_kwargs = {}

if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE:
model_kwargs["load_in_8bit"] = True
model_kwargs["load_in_8bit"] = True
else:
model_kwargs["load_in_8bit"] = False

Expand Down
12 changes: 4 additions & 8 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from transformers.utils import is_tf_available, is_torch_available

import nncf
from transformers.utils import is_tf_available, is_torch_available

from openvino.runtime import PartialShape, save_model
from openvino.runtime.utils.types import get_element_type
Expand Down Expand Up @@ -199,7 +198,7 @@ def export_pytorch_via_onnx(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
load_in_8bit=model_kwargs.get("load_in_8bit", False)
load_in_8bit=model_kwargs.get("load_in_8bit", False),
)
return input_names, output_names, True

Expand Down Expand Up @@ -323,14 +322,11 @@ def ts_patched_forward(*args, **kwargs):
dims = inputs[input_name]

for dim in dims:
static_shape[dim] = -1
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
_save_model(ov_model,
output,
compress_to_fp16=False,
load_in_8bit=model_kwargs.get("load_in_8bit", False))
_save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=model_kwargs.get("load_in_8bit", False))
clear_class_registry()
del model
gc.collect()
Expand Down
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"pegasus",
}


@add_start_docstrings(
"""
Base OVBaseDecoderModel class.
Expand Down Expand Up @@ -225,7 +226,6 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"


main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand Down

0 comments on commit 821d2a9

Please sign in to comment.