Skip to content

Commit

Permalink
fix version checking if openvino not in site-packages
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 21, 2023
1 parent 6b0236b commit ff1c737
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
5 changes: 1 addition & 4 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,7 @@ def ts_patched_forward(*args, **kwargs):
ov_model.validate_nodes_and_infer_types()

if stateful:
# Patching model according to stateful parameters
model.key_value_input_names = [name for name in input_names if name.startswith("past_key_values.")]
model.key_value_output_names = [name for name in output_names if name.startswith("present.")]
patch_stateful(model, ov_model)
patch_stateful(model.config, ov_model)

_save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio)
clear_class_registry()
Expand Down
10 changes: 5 additions & 5 deletions optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import openvino as ov
from openvino.runtime import opset13
from optimum.intel.utils.import_utils import is_openvino_version
from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
from optimum.utils.normalized_config import NormalizedConfigManager


Expand Down Expand Up @@ -121,20 +121,20 @@ def make_stateful(


def raise_if_openvino_is_too_old():
if is_openvino_version("<=", "2023.2"):
if is_openvino_version("<", "2023.3"):
raise ValueError(
f"Could not create or use stateful model when using old version of openvino=={ov.__version__}. Install openvino>=2023.3.0."
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0."
)


def patch_stateful(config, ov_model):
raise_if_openvino_is_too_old()

key_value_input_names = [
key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.names)
key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names())
]
key_value_output_names = [
key.get_any_name() for key in ov_model.output if any("present" in key_name for key_name in key.names)
key.get_any_name() for key in ov_model.outputs if any("present" in key_name for key_name in key.get_names())
]
not_kv_inputs = [
input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
Expand Down
17 changes: 11 additions & 6 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,17 @@
_openvino_version = "N/A"
if _openvino_available:
try:
_openvino_version = importlib_metadata.version("openvino")
except importlib_metadata.PackageNotFoundError:
try:
_openvino_version = importlib_metadata.version("openvino-nightly")
except importlib_metadata.PackageNotFoundError:
_openvino_available = False
from openvino.runtime import get_version

version = get_version()
# avoid invalid format
if "-" in version:
major_version, dev_info = version.split("-", 1)
commit_id = dev_info.split("-")[0]
version = f"{major_version}-{commit_id}"
_openvino_version = version
except ImportError:
_openvino_available = False


_nncf_available = importlib.util.find_spec("nncf") is not None
Expand Down

0 comments on commit ff1c737

Please sign in to comment.