diff --git a/src/ota_metadata/legacy/__init__.py b/src/ota_metadata/legacy/__init__.py index ff5f131de..6e383fff3 100644 --- a/src/ota_metadata/legacy/__init__.py +++ b/src/ota_metadata/legacy/__init__.py @@ -16,32 +16,19 @@ from __future__ import annotations -import importlib.util import sys from pathlib import Path -from types import ModuleType -_PROTO_DIR = Path(__file__).parent -_PB2_FPATH = _PROTO_DIR / "ota_metafiles_pb2.py" -_PACKAGE_PREFIX = "ota_metadata.legacy" +from otaclient_common import import_from_file +SUPORTED_COMPRESSION_TYPES = ("zst", "zstd") -def _import_from_file(path: Path) -> tuple[str, ModuleType]: - if not path.is_file(): - raise ValueError(f"{path} is not a valid module file") - try: - _module_name = path.stem - _spec = importlib.util.spec_from_file_location(_module_name, path) - _module = importlib.util.module_from_spec(_spec) # type: ignore - _spec.loader.exec_module(_module) # type: ignore - return _module_name, _module - except Exception: - raise ImportError(f"failed to import module from {path=}.") +# ------ dynamically import pb2 generated code ------ # +_PROTO_DIR = Path(__file__).parent +_PB2_FPATH = _PROTO_DIR / "ota_metafiles_pb2.py" +_PACKAGE_PREFIX = ".".join(__name__.split(".")[:-1]) -_module_name, _module = _import_from_file(_PB2_FPATH) # noqa: F821 +_module_name, _module = import_from_file(_PB2_FPATH) sys.modules[_module_name] = _module sys.modules[f"{_PACKAGE_PREFIX}.{_module_name}"] = _module - -# For OTA image format legacy, we support zst,zstd compression. -SUPORTED_COMPRESSION_TYPES = ("zst", "zstd") diff --git a/src/otaclient_api/v2/__init__.py b/src/otaclient_api/v2/__init__.py index 525453a23..ff3cd5252 100644 --- a/src/otaclient_api/v2/__init__.py +++ b/src/otaclient_api/v2/__init__.py @@ -16,34 +16,20 @@ from __future__ import annotations -import importlib.util import sys from pathlib import Path -from types import ModuleType + +from otaclient_common import import_from_file + +# ------ dynamically import pb2 generated code ------ # _PROTO_DIR = Path(__file__).parent -# NOTE: order matters here! v2_pb2_grpc depends on v2_pb2 +# NOTE: order matters here! pb2_grpc depends on pb2 _FILES_TO_LOAD = [ - _PROTO_DIR / _fname - for _fname in [ - "otaclient_v2_pb2.py", - "otaclient_v2_pb2_grpc.py", - ] + _PROTO_DIR / "otaclient_v2_pb2.py", + _PROTO_DIR / "otaclient_v2_pb2_grpc.py", ] -PACKAGE_PREFIX = "otaclient_api.v2" - - -def _import_from_file(path: Path) -> tuple[str, ModuleType]: - if not path.is_file(): - raise ValueError(f"{path} is not a valid module file") - try: - _module_name = path.stem - _spec = importlib.util.spec_from_file_location(_module_name, path) - _module = importlib.util.module_from_spec(_spec) # type: ignore - _spec.loader.exec_module(_module) # type: ignore - return _module_name, _module - except Exception: - raise ImportError(f"failed to import module from {path=}.") +PACKAGE_PREFIX = ".".join(__name__.split(".")[:-1]) def _import_pb2_proto(*module_fpaths: Path): @@ -53,12 +39,10 @@ def _import_pb2_proto(*module_fpaths: Path): imported as modules to the global namespace. """ for _fpath in module_fpaths: - _module_name, _module = _import_from_file(_fpath) # noqa: F821 - # add the module under the otaclient_api.v2 package + _module_name, _module = import_from_file(_fpath) sys.modules[f"{PACKAGE_PREFIX}.{_module_name}"] = _module - # add the module to the global module namespace sys.modules[_module_name] = _module _import_pb2_proto(*_FILES_TO_LOAD) -del _import_pb2_proto, _import_from_file +del _import_pb2_proto diff --git a/src/otaclient_common/__init__.py b/src/otaclient_common/__init__.py index ac7f63876..55144770d 100644 --- a/src/otaclient_common/__init__.py +++ b/src/otaclient_common/__init__.py @@ -16,9 +16,12 @@ from __future__ import annotations +import importlib.util import os +import sys from math import ceil from pathlib import Path +from types import ModuleType from typing import Optional from typing_extensions import Literal @@ -58,3 +61,16 @@ def replace_root(path: str | Path, old_root: str | Path, new_root: str | Path) - if os.path.commonpath([path, old_root]) != old_root: raise ValueError(f"{old_root=} is not the root of {path=}") return os.path.join(new_root, os.path.relpath(path, old_root)) + + +def import_from_file(path: Path) -> tuple[str, ModuleType]: + if not path.is_file(): + raise ValueError(f"{path} is not a valid module file") + try: + _module_name = path.stem + _spec = importlib.util.spec_from_file_location(_module_name, path) + _module = importlib.util.module_from_spec(_spec) # type: ignore + _spec.loader.exec_module(_module) # type: ignore + return _module_name, _module + except Exception: + raise ImportError(f"failed to import module from {path=}.")