diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 85196eea0..000000000 --- a/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -ota-image.* diff --git a/.gitignore b/.gitignore index 12e4e8092..5787b2cc8 100644 --- a/.gitignore +++ b/.gitignore @@ -159,10 +159,6 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -# build related -build -*.egg-info - # local vscode configs .devcontainer .vscode diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4348be26d..610b2bb88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - flake8-comprehensions - flake8-simplify - repo: https://github.com/tox-dev/pyproject-fmt - rev: "1.8.0" + rev: "2.1.3" hooks: - id: pyproject-fmt # https://pyproject-fmt.readthedocs.io/en/latest/#calculating-max-supported-python-version @@ -39,7 +39,7 @@ repos: # additional_dependencies: # - tomli - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.40.0 + rev: v0.41.0 hooks: - id: markdownlint args: ["-c", ".markdownlint.yaml", "--fix"] diff --git a/README.md b/README.md index c5a57f6aa..0e6eace46 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,23 @@ -# OTA client +# OTAClient ## Overview -This OTA client is a client software to perform over-the-air software updates for linux devices. -To enable updating of software at any layer (kernel, kernel module, user library, user application), the OTA client targets the entire rootfs for updating. -When the OTA client receives an update request, it downloads a list from the OTA server that contains the file paths and the hash values of the files, etc., to be updated, and compares them with the files in its own storage and if there is a match, that file is used to update the rootfs. By this delta mechanism, it is possible to reduce the download size even if the entire rootfs is targeted and this mechanism does not require any specific server implementation, nor does it require the server to keep a delta for each version of the rootfs. +OTAClient is software to perform over-the-air software updates for linux devices. +It provides a set of APIs for user to start the OTA and monitor the progress and status. + +It is designed to work with web.auto FMS OTA component. ## Feature -- Rootfs updating -- Delta updating -- Redundant configuration with A/B partition update -- Arbitrary files can be copied from A to B partition. This can be used to take over individual files. -- No specific server implementation is required. The server that supports HTTP GET is only required. - - TLS connection is also required. -- Delta management is not required for server side. -- To restrict access to the server, cookie can be used. -- All files to be updated are verified by the hash included in the metadata, and the metadata is also verified by X.509 certificate locally installed. -- Transfer data is encrypted by TLS -- Multiple ECU(Electronic Control Unit) support -- By the internal proxy cache mechanism, the cache can be used for the download requests to the same file from multiple ECU. +- A/B partition update with support for generic x86_64 device, NVIDIA Jetson series based devices and Raspberry Pi device. +- Full Rootfs update, with delta update support. +- Local delta calculation, allowing update to any version of OTA image without the need of a pre-generated delta OTA package. +- Support persist files from active slot to newly updated slot. +- Verification over OTA image by digital signature and PKI. +- Support for protected OTA server with cookie. +- Optional OTA proxy support and OTA cache support. +- Multiple ECU OTA supports. ## License -OTA client is licensed under the Apache License, Version 2.0. +OTAClient is licensed under the Apache License, Version 2.0. diff --git a/bootstrap/root/boot/ota/ecu_info.yaml b/bootstrap/root/boot/ota/ecu_info.yaml deleted file mode 100644 index 85f520012..000000000 --- a/bootstrap/root/boot/ota/ecu_info.yaml +++ /dev/null @@ -1,7 +0,0 @@ -format_version: 1 -ecu_id: "autoware" -#secondaries: -# - ecu_id: "perception1" -# ip_addr: "192.168.0.11" -# - ecu_id: "perception2" -# ip_addr: "192.168.0.12" diff --git a/proto/README.md b/proto/README.md new file mode 100644 index 000000000..fd819de29 --- /dev/null +++ b/proto/README.md @@ -0,0 +1,3 @@ +# OTA Service API proto + +This folder includes the OTA service API proto file, and a set of tools to generate the python lib from the proto files. diff --git a/pyproject.toml b/pyproject.toml index 41e11a020..40d2e2bea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,22 +25,21 @@ dynamic = [ ] dependencies = [ "aiofiles==22.1", - "aiohttp<3.10.0,>=3.9.5", - "cryptography<43.0.0,>=42.0.4", - "grpcio<1.54.0,>=1.53.2", - "protobuf<4.22.0,>=4.21.12", + "aiohttp<3.10,>=3.9.5", + "cryptography<43,>=42.0.4", + "grpcio<1.54,>=1.53.2", + "protobuf<4.22,>=4.21.12", "pydantic==2.7", "pydantic-settings==2.2.1", - "pyOpenSSL==24.1", - "PyYAML>=3.12", - "requests<2.32.0,>=2.31", - 'typing_extensions>=4.6.3; python_version < "3.11"', - "urllib3<2.0.0,>=1.26.8", + "pyopenssl==24.1", + "pyyaml>=3.12", + "requests<2.32,>=2.31", + "typing-extensions>=4.6.3", + "urllib3<2,>=1.26.8", "uvicorn[standard]==0.20", "zstandard==0.18", ] -[project.optional-dependencies] -dev = [ +optional-dependencies.dev = [ "black", "coverage", "flake8", @@ -50,8 +49,7 @@ dev = [ "pytest-mock==3.8.2", "requests-mock", ] -[project.urls] -Source = "https://github.com/tier4/ota-client" +urls.Source = "https://github.com/tier4/ota-client" [tool.hatch.version] source = "vcs" @@ -60,26 +58,44 @@ source = "vcs" version-file = "src/_otaclient_version.py" [tool.hatch.build.targets.sdist] -exclude = ["/tools"] +exclude = [ + "/tools", + ".github", +] [tool.hatch.build.targets.wheel] -only-include = ["src"] -sources = ["src"] +exclude = [ + "**/.gitignore", + "**/*README.md", +] +only-include = [ + "src", +] +sources = [ + "src", +] [tool.hatch.envs.dev] type = "virtual" -features = ["dev"] +features = [ + "dev", +] [tool.black] line-length = 88 -target-version = ['py38'] +target-version = [ + 'py38', +] extend-exclude = '''( ^.*(_pb2.pyi?|_pb2_grpc.pyi?)$ )''' [tool.isort] profile = "black" -extend_skip_glob = ["*_pb2.py*", "_pb2_grpc.py*"] +extend_skip_glob = [ + "*_pb2.py*", + "_pb2_grpc.py*", +] [tool.pytest.ini_options] asyncio_mode = "auto" @@ -87,16 +103,26 @@ log_auto_indent = true log_format = "%(asctime)s %(levelname)s %(filename)s %(funcName)s,%(lineno)d %(message)s" log_cli = true log_cli_level = "INFO" -pythonpath = ["otaclient"] -testpaths = ["./tests"] +testpaths = [ + "./tests", +] [tool.coverage.run] branch = false relative_files = true -source = ["otaclient"] +source = [ + "otaclient", + "otaclient_api", + "otaclient_common", + "ota_metadata", + "ota_proxy", +] [tool.coverage.report] -omit = ["**/*_pb2.py*", "**/*_pb2_grpc.py*"] +omit = [ + "**/*_pb2.py*", + "**/*_pb2_grpc.py*", +] exclude_also = [ "def __repr__", "if __name__ == .__main__.:", @@ -108,6 +134,11 @@ skip_empty = true skip_covered = true [tool.pyright] -exclude = ["**/__pycache__"] -ignore = ["**/*_pb2.py*", "**/*_pb2_grpc.py*"] +exclude = [ + "**/__pycache__", +] +ignore = [ + "**/*_pb2.py*", + "**/*_pb2_grpc.py*", +] pythonVersion = "3.8" diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 000000000..859d9a6ea --- /dev/null +++ b/samples/README.md @@ -0,0 +1,3 @@ +# OTAClient configuration files samples + +This folder contains the sample otaclient configuration files **ecu_info.yaml**, **proxy_info.yaml** and systemd service unit file **otaclient.service** for a single ECU OTA setup. diff --git a/samples/ecu_info.yaml b/samples/ecu_info.yaml new file mode 100644 index 000000000..46b3700fb --- /dev/null +++ b/samples/ecu_info.yaml @@ -0,0 +1,7 @@ +# This is the sample ecu_info.yaml for a single x86_64 ECU setup. +# Please check ecu_info.yaml spec for more details: https://tier4.atlassian.net/l/cp/AGmpqFFc. +format_version: 1 +ecu_id: autoware +bootloader: grub +available_ecu_ids: + - autoware diff --git a/bootstrap/root/etc/systemd/system/otaclient.service b/samples/otaclient.service similarity index 56% rename from bootstrap/root/etc/systemd/system/otaclient.service rename to samples/otaclient.service index 39e1552dd..7bb9d28bb 100644 --- a/bootstrap/root/etc/systemd/system/otaclient.service +++ b/samples/otaclient.service @@ -1,5 +1,3 @@ -# otaclient.service - [Unit] Description=OTA Client After=network-online.target nss-lookup.target @@ -7,9 +5,9 @@ Wants=network-online.target [Service] Type=simple -ExecStart=/bin/bash -c 'source /opt/ota/.venv/bin/activate && PYTHONPATH=/opt/ota python3 -m otaclient' +ExecStart=/opt/ota/client/venv/bin/python3 -m otaclient Restart=always -RestartSec=10 +RestartSec=16 [Install] WantedBy=multi-user.target diff --git a/samples/proxy_info.yaml b/samples/proxy_info.yaml new file mode 100644 index 000000000..5c893ca05 --- /dev/null +++ b/samples/proxy_info.yaml @@ -0,0 +1,9 @@ +# This is the sample proxy_info.yaml for a single ECU setup. +# Please check proxy_info.yaml spec for more details: https://tier4.atlassian.net/l/cp/qT4N4K0X. +format_version: 1 +enable_local_ota_proxy: true +enable_local_ota_proxy_cache: true +local_ota_proxy_listen_addr: 127.0.0.1 +local_ota_proxy_listen_port: 8082 +# if otaclient-logger is installed locally +logging_server: "http://127.0.0.1:8083" diff --git a/src/ota_metadata/README.md b/src/ota_metadata/README.md new file mode 100644 index 000000000..84dcc9084 --- /dev/null +++ b/src/ota_metadata/README.md @@ -0,0 +1,3 @@ +# OTA image metadata + +Libs for parsing OTA image. diff --git a/tools/emulator/path_loader.py b/src/ota_metadata/legacy/__init__.py similarity index 52% rename from tools/emulator/path_loader.py rename to src/ota_metadata/legacy/__init__.py index dd93760fb..6e383fff3 100644 --- a/tools/emulator/path_loader.py +++ b/src/ota_metadata/legacy/__init__.py @@ -11,19 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""OTA image metadata, legacy version.""" -# NOTE: this file should only be loaded once by the program entry! +from __future__ import annotations +import sys +from pathlib import Path -###### load path ###### -def _path_load(): - import sys - from pathlib import Path +from otaclient_common import import_from_file - project_base = Path(__file__).absolute().parent.parent - sys.path.extend([str(project_base), str(project_base / "app")]) +SUPORTED_COMPRESSION_TYPES = ("zst", "zstd") +# ------ dynamically import pb2 generated code ------ # -_path_load() -###### +_PROTO_DIR = Path(__file__).parent +_PB2_FPATH = _PROTO_DIR / "ota_metafiles_pb2.py" +_PACKAGE_PREFIX = ".".join(__name__.split(".")[:-1]) + +_module_name, _module = import_from_file(_PB2_FPATH) +sys.modules[_module_name] = _module +sys.modules[f"{_PACKAGE_PREFIX}.{_module_name}"] = _module diff --git a/src/otaclient/app/proto/ota_metafiles_pb2.py b/src/ota_metadata/legacy/ota_metafiles_pb2.py similarity index 100% rename from src/otaclient/app/proto/ota_metafiles_pb2.py rename to src/ota_metadata/legacy/ota_metafiles_pb2.py diff --git a/src/otaclient/app/proto/ota_metafiles_pb2.pyi b/src/ota_metadata/legacy/ota_metafiles_pb2.pyi similarity index 100% rename from src/otaclient/app/proto/ota_metafiles_pb2.pyi rename to src/ota_metadata/legacy/ota_metafiles_pb2.pyi diff --git a/src/otaclient/app/ota_metadata.py b/src/ota_metadata/legacy/parser.py similarity index 94% rename from src/otaclient/app/ota_metadata.py rename to src/ota_metadata/legacy/parser.py index b69a99241..25283004c 100644 --- a/src/otaclient/app/ota_metadata.py +++ b/src/ota_metadata/legacy/parser.py @@ -72,18 +72,17 @@ from typing_extensions import Self from ota_proxy import OTAFileCacheControl - -from .common import RetryTaskMap, get_backoff, urljoin_ensure_base -from .configs import config as cfg -from .downloader import Downloader -from .proto.streamer import Uint32LenDelimitedMsgReader, Uint32LenDelimitedMsgWriter -from .proto.wrapper import ( - DirectoryInf, - MessageWrapper, - PersistentInf, - RegularInf, - SymbolicLinkInf, +from otaclient_common.common import get_backoff, urljoin_ensure_base +from otaclient_common.downloader import Downloader +from otaclient_common.proto_streamer import ( + Uint32LenDelimitedMsgReader, + Uint32LenDelimitedMsgWriter, ) +from otaclient_common.proto_wrapper import MessageWrapper +from otaclient_common.retry_task_map import RetryTaskMap + +from . import SUPORTED_COMPRESSION_TYPES +from .types import DirectoryInf, PersistentInf, RegularInf, SymbolicLinkInf logger = logging.getLogger(__name__) @@ -592,10 +591,25 @@ class OTAMetadata: ), } - def __init__(self, *, url_base: str, downloader: Downloader) -> None: + MAX_COCURRENT = 2 + BACKOFF_FACTOR = 1 + BACKOFF_MAX = 6 + + def __init__( + self, + *, + url_base: str, + downloader: Downloader, + run_dir: Path, + certs_dir: Path, + download_max_idle_time: int, + ) -> None: self.url_base = url_base self._downloader = downloader - self._tmp_dir = TemporaryDirectory(prefix="ota_metadata", dir=cfg.RUN_DIR) + self.run_dir = run_dir + self.certs_dir = certs_dir + self.download_max_idle_time = download_max_idle_time + self._tmp_dir = TemporaryDirectory(prefix="ota_metadata", dir=run_dir) self._tmp_dir_path = Path(self._tmp_dir.name) # download and parse the metadata.jwt @@ -622,7 +636,7 @@ def _process_metadata_jwt(self) -> _MetadataJWTClaimsLayout: """Download, loading and parsing metadata.jwt.""" logger.debug("process metadata.jwt...") # download and parse metadata.jwt - with NamedTemporaryFile(prefix="metadata_jwt", dir=cfg.RUN_DIR) as meta_f: + with NamedTemporaryFile(prefix="metadata_jwt", dir=self.run_dir) as meta_f: _downloaded_meta_f = Path(meta_f.name) self._downloader.download_retry_inf( urljoin_ensure_base(self.url_base, self.METADATA_JWT), @@ -636,13 +650,13 @@ def _process_metadata_jwt(self) -> _MetadataJWTClaimsLayout: ) _parser = _MetadataJWTParser( - _downloaded_meta_f.read_text(), certs_dir=cfg.CERTS_DIR + _downloaded_meta_f.read_text(), certs_dir=self.certs_dir ) # get not yet verified parsed ota_metadata _ota_metadata = _parser.get_otametadata() # download certificate and verify metadata against this certificate - with NamedTemporaryFile(prefix="metadata_cert", dir=cfg.RUN_DIR) as cert_f: + with NamedTemporaryFile(prefix="metadata_cert", dir=self.run_dir) as cert_f: cert_info = _ota_metadata.certificate cert_fname, cert_hash = cert_info.file, cert_info.hash cert_file = Path(cert_f.name) @@ -696,11 +710,11 @@ def _process_text_base_otameta_file(_metafile: MetaFile): last_active_timestamp = int(time.time()) _mapper = RetryTaskMap( - max_concurrent=cfg.MAX_CONCURRENT_DOWNLOAD_TASKS, + max_concurrent=self.MAX_COCURRENT, backoff_func=partial( get_backoff, - factor=cfg.DOWNLOAD_GROUP_BACKOFF_FACTOR, - _max=cfg.DOWNLOAD_GROUP_BACKOFF_MAX, + factor=self.BACKOFF_FACTOR, + _max=self.BACKOFF_MAX, ), max_retry=0, # NOTE: we use another strategy below ) @@ -718,12 +732,9 @@ def _process_text_base_otameta_file(_metafile: MetaFile): last_active_timestamp = max( last_active_timestamp, self._downloader.last_active_timestamp ) - if ( - int(time.time()) - last_active_timestamp - > cfg.DOWNLOAD_GROUP_INACTIVE_TIMEOUT - ): + if int(time.time()) - last_active_timestamp > self.download_max_idle_time: logger.error( - f"downloader becomes stuck for {cfg.DOWNLOAD_GROUP_INACTIVE_TIMEOUT=} seconds, abort" + f"downloader becomes stuck for {self.download_max_idle_time=} seconds, abort" ) _mapper.shutdown(raise_last_exc=True) @@ -753,7 +764,7 @@ def get_download_url(self, reg_inf: RegularInf) -> Tuple[str, Optional[str]]: if ( self.image_compressed_rootfs_url and reg_inf.compressed_alg - and reg_inf.compressed_alg in cfg.SUPPORTED_COMPRESS_ALG + and reg_inf.compressed_alg in SUPORTED_COMPRESSION_TYPES ): return ( urljoin_ensure_base( diff --git a/src/otaclient/app/proto/_ota_metafiles_wrapper.py b/src/ota_metadata/legacy/types.py similarity index 97% rename from src/otaclient/app/proto/_ota_metafiles_wrapper.py rename to src/ota_metadata/legacy/types.py index bca6a24f8..c14807470 100644 --- a/src/otaclient/app/proto/_ota_metafiles_wrapper.py +++ b/src/ota_metadata/legacy/types.py @@ -20,9 +20,8 @@ from pathlib import Path from typing import Union -import ota_metafiles_pb2 as ota_metafiles - -from ._common import MessageWrapper, calculate_slots +from ota_metadata.legacy import ota_metafiles_pb2 as ota_metafiles +from otaclient_common.proto_wrapper import MessageWrapper, calculate_slots # helper mixin diff --git a/src/ota_proxy/cache_control.py b/src/ota_proxy/cache_control.py index 2ec4ceeca..3f5548009 100644 --- a/src/ota_proxy/cache_control.py +++ b/src/ota_proxy/cache_control.py @@ -18,7 +18,7 @@ from typing_extensions import Self -from otaclient._utils import copy_callable_typehint_to_method +from otaclient_common.typing import copy_callable_typehint_to_method _FIELDS = "_fields" diff --git a/src/ota_proxy/server_app.py b/src/ota_proxy/server_app.py index 2349ff284..0ab4e467c 100644 --- a/src/ota_proxy/server_app.py +++ b/src/ota_proxy/server_app.py @@ -22,7 +22,7 @@ import aiohttp -from otaclient._utils.logging import BurstSuppressFilter +from otaclient_common.logging import BurstSuppressFilter from ._consts import ( BHEADER_AUTHORIZATION, diff --git a/src/otaclient/app/boot_control/_common.py b/src/otaclient/app/boot_control/_common.py index 7064a31e6..9f146f773 100644 --- a/src/otaclient/app/boot_control/_common.py +++ b/src/otaclient/app/boot_control/_common.py @@ -24,14 +24,14 @@ from subprocess import CalledProcessError from typing import Callable, Literal, NoReturn, Optional, Union -from ..common import ( +from otaclient.app.configs import config as cfg +from otaclient_api.v2 import types as api_types +from otaclient_common.common import ( read_str_from_file, subprocess_call, subprocess_check_output, write_str_to_file_sync, ) -from ..configs import config as cfg -from ..proto import wrapper logger = logging.getLogger(__name__) @@ -451,18 +451,18 @@ def _load_status_file(self): if _loaded_ota_status is None: logger.info( "ota_status files incompleted/not presented, " - f"initializing and set/store status to {wrapper.StatusOta.INITIALIZED.name}..." + f"initializing and set/store status to {api_types.StatusOta.INITIALIZED.name}..." ) - self._store_current_status(wrapper.StatusOta.INITIALIZED) - self._ota_status = wrapper.StatusOta.INITIALIZED + self._store_current_status(api_types.StatusOta.INITIALIZED) + self._ota_status = api_types.StatusOta.INITIALIZED return logger.info(f"status loaded from file: {_loaded_ota_status.name}") # status except UPDATING and ROLLBACKING(like SUCCESS/FAILURE/ROLLBACK_FAILURE) # are remained as it if _loaded_ota_status not in [ - wrapper.StatusOta.UPDATING, - wrapper.StatusOta.ROLLBACKING, + api_types.StatusOta.UPDATING, + api_types.StatusOta.ROLLBACKING, ]: self._ota_status = _loaded_ota_status return @@ -478,13 +478,13 @@ def _load_status_file(self): # in such case, otaclient will terminate and ota_status will not be updated. if self._is_switching_boot(self.active_slot): if self.finalize_switching_boot(): - self._ota_status = wrapper.StatusOta.SUCCESS - self._store_current_status(wrapper.StatusOta.SUCCESS) + self._ota_status = api_types.StatusOta.SUCCESS + self._store_current_status(api_types.StatusOta.SUCCESS) else: self._ota_status = ( - wrapper.StatusOta.ROLLBACK_FAILURE - if _loaded_ota_status == wrapper.StatusOta.ROLLBACKING - else wrapper.StatusOta.FAILURE + api_types.StatusOta.ROLLBACK_FAILURE + if _loaded_ota_status == api_types.StatusOta.ROLLBACKING + else api_types.StatusOta.FAILURE ) self._store_current_status(self._ota_status) logger.error( @@ -498,9 +498,9 @@ def _load_status_file(self): "this indicates a failed first reboot" ) self._ota_status = ( - wrapper.StatusOta.ROLLBACK_FAILURE - if _loaded_ota_status == wrapper.StatusOta.ROLLBACKING - else wrapper.StatusOta.FAILURE + api_types.StatusOta.ROLLBACK_FAILURE + if _loaded_ota_status == api_types.StatusOta.ROLLBACKING + else api_types.StatusOta.FAILURE ) self._store_current_status(self._ota_status) @@ -545,23 +545,23 @@ def _load_current_slot_in_use(self) -> Optional[str]: # status control - def _store_current_status(self, _status: wrapper.StatusOta): + def _store_current_status(self, _status: api_types.StatusOta): write_str_to_file_sync( self.current_ota_status_dir / cfg.OTA_STATUS_FNAME, _status.name ) - def _store_standby_status(self, _status: wrapper.StatusOta): + def _store_standby_status(self, _status: api_types.StatusOta): write_str_to_file_sync( self.standby_ota_status_dir / cfg.OTA_STATUS_FNAME, _status.name ) - def _load_current_status(self) -> Optional[wrapper.StatusOta]: + def _load_current_status(self) -> Optional[api_types.StatusOta]: if _status_str := read_str_from_file( self.current_ota_status_dir / cfg.OTA_STATUS_FNAME ).upper(): with contextlib.suppress(KeyError): # invalid status string - return wrapper.StatusOta[_status_str] + return api_types.StatusOta[_status_str] # version control @@ -577,8 +577,8 @@ def _is_switching_boot(self, active_slot: str) -> bool: """Detect whether we should switch boot or not with ota_status files.""" # evidence: ota_status _is_updating_or_rollbacking = self._load_current_status() in [ - wrapper.StatusOta.UPDATING, - wrapper.StatusOta.ROLLBACKING, + api_types.StatusOta.UPDATING, + api_types.StatusOta.ROLLBACKING, ] # evidence: slot_in_use @@ -598,7 +598,7 @@ def _is_switching_boot(self, active_slot: str) -> bool: def pre_update_current(self): """On pre_update stage, set current slot's status to FAILURE and set slot_in_use to standby slot.""" - self._store_current_status(wrapper.StatusOta.FAILURE) + self._store_current_status(api_types.StatusOta.FAILURE) self._store_current_slot_in_use(self.standby_slot) def pre_update_standby(self, *, version: str): @@ -610,17 +610,17 @@ def pre_update_standby(self, *, version: str): # create the ota-status folder unconditionally self.standby_ota_status_dir.mkdir(exist_ok=True, parents=True) # store status to standby slot - self._store_standby_status(wrapper.StatusOta.UPDATING) + self._store_standby_status(api_types.StatusOta.UPDATING) self._store_standby_version(version) self._store_standby_slot_in_use(self.standby_slot) def pre_rollback_current(self): - self._store_current_status(wrapper.StatusOta.FAILURE) + self._store_current_status(api_types.StatusOta.FAILURE) def pre_rollback_standby(self): # store ROLLBACKING status to standby self.standby_ota_status_dir.mkdir(exist_ok=True, parents=True) - self._store_standby_status(wrapper.StatusOta.ROLLBACKING) + self._store_standby_status(api_types.StatusOta.ROLLBACKING) def load_active_slot_version(self) -> str: return read_str_from_file( @@ -631,13 +631,13 @@ def load_active_slot_version(self) -> str: def on_failure(self): """Store FAILURE to status file on failure.""" - self._store_current_status(wrapper.StatusOta.FAILURE) + self._store_current_status(api_types.StatusOta.FAILURE) # when standby slot is not created, otastatus is not needed to be set if self.standby_ota_status_dir.is_dir(): - self._store_standby_status(wrapper.StatusOta.FAILURE) + self._store_standby_status(api_types.StatusOta.FAILURE) @property - def booted_ota_status(self) -> wrapper.StatusOta: + def booted_ota_status(self) -> api_types.StatusOta: """Loaded current slot's ota_status during boot control starts. NOTE: distinguish between the live ota_status maintained by otaclient. diff --git a/src/otaclient/app/boot_control/_grub.py b/src/otaclient/app/boot_control/_grub.py index 90766691c..dc01370b8 100644 --- a/src/otaclient/app/boot_control/_grub.py +++ b/src/otaclient/app/boot_control/_grub.py @@ -42,23 +42,23 @@ from subprocess import CalledProcessError from typing import ClassVar, Dict, Generator, List, Optional, Tuple -from .. import errors as ota_errors -from ..common import ( +from otaclient.app import errors as ota_errors +from otaclient.app.boot_control._common import ( + CMDHelperFuncs, + OTAStatusFilesControl, + SlotMountHelper, + cat_proc_cmdline, +) +from otaclient.app.boot_control.configs import grub_cfg as cfg +from otaclient.app.boot_control.protocol import BootControllerProtocol +from otaclient_api.v2 import types as api_types +from otaclient_common.common import ( re_symlink_atomic, read_str_from_file, subprocess_call, subprocess_check_output, write_str_to_file_sync, ) -from ..proto import wrapper -from ._common import ( - CMDHelperFuncs, - OTAStatusFilesControl, - SlotMountHelper, - cat_proc_cmdline, -) -from .configs import grub_cfg as cfg -from .protocol import BootControllerProtocol logger = logging.getLogger(__name__) @@ -865,7 +865,7 @@ def get_standby_boot_dir(self) -> Path: def load_version(self) -> str: return self._ota_status_control.load_active_slot_version() - def get_booted_ota_status(self) -> wrapper.StatusOta: + def get_booted_ota_status(self) -> api_types.StatusOta: return self._ota_status_control.booted_ota_status def on_operation_failure(self): diff --git a/src/otaclient/app/boot_control/_jetson_cboot.py b/src/otaclient/app/boot_control/_jetson_cboot.py index 4d9eb98b9..40b751ab7 100644 --- a/src/otaclient/app/boot_control/_jetson_cboot.py +++ b/src/otaclient/app/boot_control/_jetson_cboot.py @@ -26,12 +26,12 @@ from typing import Generator, Optional from otaclient.app import errors as ota_errors -from otaclient.app.common import subprocess_run_wrapper -from otaclient.app.proto import wrapper - -from ..configs import config as cfg -from ._common import CMDHelperFuncs, OTAStatusFilesControl, SlotMountHelper -from ._jetson_common import ( +from otaclient.app.boot_control._common import ( + CMDHelperFuncs, + OTAStatusFilesControl, + SlotMountHelper, +) +from otaclient.app.boot_control._jetson_common import ( FirmwareBSPVersionControl, NVBootctrlCommon, NVBootctrlTarget, @@ -41,8 +41,14 @@ preserve_ota_config_files_to_standby, update_standby_slot_extlinux_cfg, ) -from .configs import cboot_cfg as boot_cfg -from .protocol import BootControllerProtocol +from otaclient.app.boot_control.configs import cboot_cfg as boot_cfg +from otaclient.app.boot_control.protocol import BootControllerProtocol +from otaclient.app.configs import config as cfg +from otaclient_api.v2 import types as api_types +from otaclient_common.common import subprocess_run_wrapper + +logger = logging.getLogger(__name__) + logger = logging.getLogger(__name__) @@ -617,5 +623,5 @@ def on_operation_failure(self): def load_version(self) -> str: return self._ota_status_control.load_active_slot_version() - def get_booted_ota_status(self) -> wrapper.StatusOta: + def get_booted_ota_status(self) -> api_types.StatusOta: return self._ota_status_control.booted_ota_status diff --git a/src/otaclient/app/boot_control/_jetson_common.py b/src/otaclient/app/boot_control/_jetson_common.py index b01fab29a..b4d3b79b6 100644 --- a/src/otaclient/app/boot_control/_jetson_common.py +++ b/src/otaclient/app/boot_control/_jetson_common.py @@ -29,10 +29,8 @@ from pydantic import BaseModel, BeforeValidator, PlainSerializer from typing_extensions import Annotated, Literal, Self -from otaclient.app.common import write_str_to_file_sync - -from ..common import copytree_identical -from ._common import CMDHelperFuncs +from otaclient.app.boot_control._common import CMDHelperFuncs +from otaclient_common.common import copytree_identical, write_str_to_file_sync logger = logging.getLogger(__name__) diff --git a/src/otaclient/app/boot_control/_rpi_boot.py b/src/otaclient/app/boot_control/_rpi_boot.py index 3c78ae94f..330d7022b 100644 --- a/src/otaclient/app/boot_control/_rpi_boot.py +++ b/src/otaclient/app/boot_control/_rpi_boot.py @@ -23,17 +23,21 @@ from string import Template from typing import Generator -from .. import errors as ota_errors -from ..common import replace_atomic, subprocess_call, subprocess_check_output -from ..proto import wrapper -from ._common import ( +import otaclient.app.errors as ota_errors +from otaclient.app.boot_control._common import ( CMDHelperFuncs, OTAStatusFilesControl, SlotMountHelper, write_str_to_file_sync, ) -from .configs import rpi_boot_cfg as cfg -from .protocol import BootControllerProtocol +from otaclient.app.boot_control.configs import rpi_boot_cfg as cfg +from otaclient.app.boot_control.protocol import BootControllerProtocol +from otaclient_api.v2 import types as api_types +from otaclient_common.common import ( + replace_atomic, + subprocess_call, + subprocess_check_output, +) logger = logging.getLogger(__name__) @@ -375,8 +379,8 @@ def __init__(self) -> None: # 20230613: remove any leftover flag file if ota_status is not UPDATING/ROLLBACKING if self._ota_status_control.booted_ota_status not in ( - wrapper.StatusOta.UPDATING, - wrapper.StatusOta.ROLLBACKING, + api_types.StatusOta.UPDATING, + api_types.StatusOta.ROLLBACKING, ): _flag_file = ( self._rpiboot_control.system_boot_path / cfg.SWITCH_BOOT_FLAG_FILE @@ -546,5 +550,5 @@ def on_operation_failure(self): def load_version(self) -> str: return self._ota_status_control.load_active_slot_version() - def get_booted_ota_status(self) -> wrapper.StatusOta: + def get_booted_ota_status(self) -> api_types.StatusOta: return self._ota_status_control.booted_ota_status diff --git a/src/otaclient/app/boot_control/protocol.py b/src/otaclient/app/boot_control/protocol.py index 9883aee47..d9f579685 100644 --- a/src/otaclient/app/boot_control/protocol.py +++ b/src/otaclient/app/boot_control/protocol.py @@ -17,14 +17,14 @@ from pathlib import Path from typing import Generator, Protocol -from ..proto import wrapper +from otaclient_api.v2 import types as api_types class BootControllerProtocol(Protocol): """Boot controller protocol for otaclient.""" @abstractmethod - def get_booted_ota_status(self) -> wrapper.StatusOta: + def get_booted_ota_status(self) -> api_types.StatusOta: """Get the ota_status loaded from status file during otaclient starts up. This value is meant to be used only once during otaclient starts up, diff --git a/src/otaclient/app/boot_control/selecter.py b/src/otaclient/app/boot_control/selecter.py index 7e08982fa..778ff4f3d 100644 --- a/src/otaclient/app/boot_control/selecter.py +++ b/src/otaclient/app/boot_control/selecter.py @@ -21,9 +21,9 @@ from typing_extensions import deprecated -from ..common import read_str_from_file -from .configs import BootloaderType -from .protocol import BootControllerProtocol +from otaclient.app.boot_control.configs import BootloaderType +from otaclient.app.boot_control.protocol import BootControllerProtocol +from otaclient_common.common import read_str_from_file logger = logging.getLogger(__name__) diff --git a/src/otaclient/app/common.py b/src/otaclient/app/common.py deleted file mode 100644 index 08e874917..000000000 --- a/src/otaclient/app/common.py +++ /dev/null @@ -1,805 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -r"""Utils that shared between modules are listed here.""" - - -from __future__ import annotations - -import itertools -import logging -import os -import shlex -import shutil -import subprocess -import threading -import time -from concurrent.futures import Future, ThreadPoolExecutor -from functools import lru_cache, partial -from hashlib import sha256 -from pathlib import Path -from queue import Queue -from typing import ( - Any, - Callable, - Generator, - Generic, - Iterable, - NamedTuple, - Optional, - Set, - TypeVar, - Union, -) -from urllib.parse import urljoin - -import requests - -from otaclient._utils.linux import ( - ParsedGroup, - ParsedPasswd, - map_gid_by_grpnam, - map_uid_by_pwnam, -) - -from .configs import config as cfg - -logger = logging.getLogger(__name__) - - -def get_backoff(n: int, factor: float, _max: float) -> float: - return min(_max, factor * (2 ** (n - 1))) - - -def wait_with_backoff(_retry_cnt: int, *, _backoff_factor: float, _backoff_max: float): - time.sleep( - get_backoff( - _retry_cnt, - _backoff_factor, - _backoff_max, - ) - ) - - -# file verification -def file_sha256(filename: Union[Path, str]) -> str: - with open(filename, "rb") as f: - m = sha256() - while True: - d = f.read(cfg.LOCAL_CHUNK_SIZE) - if len(d) == 0: - break - m.update(d) - return m.hexdigest() - - -def verify_file(fpath: Path, fhash: str, fsize: Optional[int]) -> bool: - if ( - fpath.is_symlink() - or (not fpath.is_file()) - or (fsize is not None and fpath.stat().st_size != fsize) - ): - return False - return file_sha256(fpath) == fhash - - -# handled file read/write -def read_str_from_file(path: Union[Path, str], *, missing_ok=True, default="") -> str: - """ - Params: - missing_ok: if set to False, FileNotFoundError will be raised to upper - default: the default value to return when missing_ok=True and file not found - """ - try: - return Path(path).read_text().strip() - except FileNotFoundError: - if missing_ok: - return default - - raise - - -def write_str_to_file(path: Path, input: str): - path.write_text(input) - - -def write_str_to_file_sync(path: Union[Path, str], input: str): - with open(path, "w") as f: - f.write(input) - f.flush() - os.fsync(f.fileno()) - - -def subprocess_run_wrapper( - cmd: str | list[str], - *, - check: bool, - check_output: bool, - timeout: Optional[float] = None, -) -> subprocess.CompletedProcess[bytes]: - """A wrapper for subprocess.run method. - - NOTE: this is for the requirement of customized subprocess call - in the future, like chroot or nsenter before execution. - - Args: - cmd (str | list[str]): command to be executed. - check (bool): if True, raise CalledProcessError on non 0 return code. - check_output (bool): if True, the UTF-8 decoded stdout will be returned. - timeout (Optional[float], optional): timeout for execution. Defaults to None. - - Returns: - subprocess.CompletedProcess[bytes]: the result of the execution. - """ - if isinstance(cmd, str): - cmd = shlex.split(cmd) - - return subprocess.run( - cmd, - check=check, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE if check_output else None, - timeout=timeout, - ) - - -def subprocess_check_output( - cmd: str | list[str], - *, - raise_exception: bool = False, - default: str = "", - timeout: Optional[float] = None, -) -> str: - """Run the and return UTF-8 decoded stripped stdout. - - Args: - cmd (str | list[str]): command to be executed. - raise_exception (bool, optional): raise the underlying CalledProcessError. Defaults to False. - default (str, optional): if is False, return on underlying - subprocess call failed. Defaults to "". - timeout (Optional[float], optional): timeout for execution. Defaults to None. - - Returns: - str: UTF-8 decoded stripped stdout. - """ - try: - res = subprocess_run_wrapper( - cmd, check=True, check_output=True, timeout=timeout - ) - return res.stdout.decode().strip() - except subprocess.CalledProcessError as e: - _err_msg = ( - f"command({cmd=}) failed(retcode={e.returncode}: \n" - f"stderr={e.stderr.decode()}" - ) - logger.debug(_err_msg) - - if raise_exception: - raise - return default - - -def subprocess_call( - cmd: str | list[str], - *, - raise_exception: bool = False, - timeout: Optional[float] = None, -) -> None: - """Run the . - - Args: - cmd (str | list[str]): command to be executed. - raise_exception (bool, optional): raise the underlying CalledProcessError. Defaults to False. - timeout (Optional[float], optional): timeout for execution. Defaults to None. - """ - try: - subprocess_run_wrapper(cmd, check=True, check_output=False, timeout=timeout) - except subprocess.CalledProcessError as e: - _err_msg = ( - f"command({cmd=}) failed(retcode={e.returncode}: \n" - f"stderr={e.stderr.decode()}" - ) - logger.debug(_err_msg) - - if raise_exception: - raise - - -def copy_stat(src: Union[Path, str], dst: Union[Path, str]): - """Copy file/dir permission bits and owner info from src to dst.""" - _stat = Path(src).stat() - os.chown(dst, _stat.st_uid, _stat.st_gid) - os.chmod(dst, _stat.st_mode) - - -def copytree_identical(src: Path, dst: Path): - """Recursively copy from the src folder to dst folder. - - Source folder MUST be a dir. - - This function populate files/dirs from the src to the dst, - and make sure the dst is identical to the src. - - By updating the dst folder in-place, we can prevent the case - that the copy is interrupted and the dst is not yet fully populated. - - This function is different from shutil.copytree as follow: - 1. it covers the case that the same path points to different - file type, in this case, the dst path will be clean and - new file/dir will be populated as the src. - 2. it deals with the same symlinks by checking the link target, - re-generate the symlink if the dst symlink is not the same - as the src. - 3. it will remove files that not presented in the src, and - unconditionally override files with same path, ensuring - that the dst will be identical with the src. - - NOTE: is_file/is_dir also returns True if it is a symlink and - the link target is_file/is_dir - """ - if src.is_symlink() or not src.is_dir(): - raise ValueError(f"{src} is not a dir") - - if dst.is_symlink() or not dst.is_dir(): - logger.info(f"{dst=} doesn't exist or not a dir, cleanup and mkdir") - dst.unlink(missing_ok=True) # unlink doesn't follow the symlink - dst.mkdir(mode=src.stat().st_mode, parents=True) - - # phase1: populate files to the dst - for cur_dir, dirs, files in os.walk(src, topdown=True, followlinks=False): - _cur_dir = Path(cur_dir) - _cur_dir_on_dst = dst / _cur_dir.relative_to(src) - - # NOTE(20220803): os.walk now lists symlinks pointed to dir - # in the tuple, we have to handle this behavior - for _dir in dirs: - _src_dir = _cur_dir / _dir - _dst_dir = _cur_dir_on_dst / _dir - if _src_dir.is_symlink(): # this "dir" is a symlink to a dir - if (not _dst_dir.is_symlink()) and _dst_dir.is_dir(): - # if dst is a dir, remove it - shutil.rmtree(_dst_dir, ignore_errors=True) - else: # dst is symlink or file - _dst_dir.unlink() - _dst_dir.symlink_to(os.readlink(_src_dir)) - - # cover the edge case that dst is not a dir. - if _cur_dir_on_dst.is_symlink() or not _cur_dir_on_dst.is_dir(): - _cur_dir_on_dst.unlink(missing_ok=True) - _cur_dir_on_dst.mkdir(parents=True) - copy_stat(_cur_dir, _cur_dir_on_dst) - - # populate files - for fname in files: - _src_f = _cur_dir / fname - _dst_f = _cur_dir_on_dst / fname - - # prepare dst - # src is file but dst is a folder - # delete the dst in advance - if (not _dst_f.is_symlink()) and _dst_f.is_dir(): - # if dst is a dir, remove it - shutil.rmtree(_dst_f, ignore_errors=True) - else: - # dst is symlink or file - _dst_f.unlink(missing_ok=True) - - # copy/symlink dst as src - # if src is symlink, check symlink, re-link if needed - if _src_f.is_symlink(): - _dst_f.symlink_to(os.readlink(_src_f)) - else: - # copy/override src to dst - shutil.copy(_src_f, _dst_f, follow_symlinks=False) - copy_stat(_src_f, _dst_f) - - # phase2: remove unused files in the dst - for cur_dir, dirs, files in os.walk(dst, topdown=True, followlinks=False): - _cur_dir_on_dst = Path(cur_dir) - _cur_dir_on_src = src / _cur_dir_on_dst.relative_to(dst) - - # remove unused dir - if not _cur_dir_on_src.is_dir(): - shutil.rmtree(_cur_dir_on_dst, ignore_errors=True) - dirs.clear() # stop iterate the subfolders of this dir - continue - - # NOTE(20220803): os.walk now lists symlinks pointed to dir - # in the tuple, we have to handle this behavior - for _dir in dirs: - _src_dir = _cur_dir_on_src / _dir - _dst_dir = _cur_dir_on_dst / _dir - if (not _src_dir.is_symlink()) and _dst_dir.is_symlink(): - _dst_dir.unlink() - - for fname in files: - _src_f = _cur_dir_on_src / fname - if not (_src_f.is_symlink() or _src_f.is_file()): - (_cur_dir_on_dst / fname).unlink(missing_ok=True) - - -def re_symlink_atomic(src: Path, target: Union[Path, str]): - """Make the a symlink to atomically. - - If the src is already existed as a file/symlink, - the src will be replaced by the newly created link unconditionally. - - NOTE: os.rename is atomic when src and dst are on - the same filesystem under linux. - NOTE 2: src should not exist or exist as file/symlink. - """ - if not (src.is_symlink() and str(os.readlink(src)) == str(target)): - tmp_link = Path(src).parent / f"tmp_link_{os.urandom(6).hex()}" - try: - tmp_link.symlink_to(target) - os.rename(tmp_link, src) # unconditionally override - except Exception: - tmp_link.unlink(missing_ok=True) - raise - - -def replace_atomic(src: Union[str, Path], dst: Union[str, Path]): - """Atomically replace dst file with src file. - - NOTE: atomic is ensured by os.rename/os.replace under the same filesystem. - """ - src, dst = Path(src), Path(dst) - if not src.is_file(): - raise ValueError(f"{src=} is not a regular file or not exist") - - _tmp_file = dst.parent / f".tmp_{os.urandom(6).hex()}" - try: - # prepare a copy of src file under dst's parent folder - shutil.copy(src, _tmp_file, follow_symlinks=True) - # atomically rename/replace the dst file with the copy - os.replace(_tmp_file, dst) - os.sync() - except Exception: - _tmp_file.unlink(missing_ok=True) - raise - - -def urljoin_ensure_base(base: str, url: str): - """ - NOTE: this method ensure the base_url will be preserved. - for example: - base="http://example.com/data", url="path/to/file" - with urljoin, joined url will be "http://example.com/path/to/file", - with this func, joined url will be "http://example.com/data/path/to/file" - """ - return urljoin(f"{base.rstrip('/')}/", url) - - -# ------ RetryTaskMap ------ # - -T = TypeVar("T") - - -class DoneTask(NamedTuple): - fut: Future - entry: Any - - -class RetryTaskMapInterrupted(Exception): - pass - - -class _TaskMap(Generic[T]): - def __init__( - self, - executor: ThreadPoolExecutor, - max_concurrent: int, - backoff_func: Callable[[int], float], - ) -> None: - # task dispatch interval for continues failling - self.started = False # can only be started once - self._backoff_func = backoff_func - self._executor = executor - self._shutdown_event = threading.Event() - self._se = threading.Semaphore(max_concurrent) - - self._total_tasks_count = 0 - self._dispatched_tasks: Set[Future] = set() - self._failed_tasks: Set[T] = set() - self._last_failed_fut: Optional[Future] = None - - # NOTE: itertools.count is only thread-safe in CPython with GIL, - # as itertools.count is pure C implemented, calling next over - # it is atomic in Python level. - self._done_task_counter = itertools.count(start=1) - self._all_done = threading.Event() - self._dispatch_done = False - - self._done_que: Queue[DoneTask] = Queue() - - def _done_task_cb(self, item: T, fut: Future): - """ - Tracking done counting, set all_done event. - add failed to failed list. - """ - self._se.release() # always release se first - # NOTE: don't change dispatched_tasks if shutdown_event is set - if self._shutdown_event.is_set(): - return - - self._dispatched_tasks.discard(fut) - # check if we finish all tasks - _done_task_num = next(self._done_task_counter) - if self._dispatch_done and _done_task_num == self._total_tasks_count: - logger.debug("all done!") - self._all_done.set() - - if fut.exception(): - self._failed_tasks.add(item) - self._last_failed_fut = fut - self._done_que.put_nowait(DoneTask(fut, item)) - - def _task_dispatcher(self, func: Callable[[T], Any], _iter: Iterable[T]): - """A dispatcher in a dedicated thread that dispatches - tasks to threadpool.""" - for item in _iter: - if self._shutdown_event.is_set(): - return - self._se.acquire() - self._total_tasks_count += 1 - - fut = self._executor.submit(func, item) - fut.add_done_callback(partial(self._done_task_cb, item)) - self._dispatched_tasks.add(fut) - logger.debug(f"dispatcher done: {self._total_tasks_count=}") - self._dispatch_done = True - - def _done_task_collector(self) -> Generator[DoneTask, None, None]: - """A generator for caller to yield done task from.""" - _count = 0 - while not self._shutdown_event.is_set(): - if self._all_done.is_set() and _count == self._total_tasks_count: - logger.debug("collector done!") - return - - yield self._done_que.get() - _count += 1 - - def map(self, func: Callable[[T], Any], _iter: Iterable[T]): - if self.started: - raise ValueError(f"{self.__class__} inst can only be started once") - self.started = True - - self._task_dispatcher_fut = self._executor.submit( - self._task_dispatcher, func, _iter - ) - self._task_collector_gen = self._done_task_collector() - return self._task_collector_gen - - def shutdown(self, *, raise_last_exc=False) -> Optional[Set[T]]: - """Set the shutdown event, and cancal/cleanup ongoing tasks.""" - if not self.started or self._shutdown_event.is_set(): - return - - self._shutdown_event.set() - self._task_collector_gen.close() - # wait for dispatch to stop - self._task_dispatcher_fut.result() - - # cancel all the dispatched tasks - for fut in self._dispatched_tasks: - fut.cancel() - self._dispatched_tasks.clear() - - if not self._failed_tasks: - return - try: - if self._last_failed_fut: - _exc = self._last_failed_fut.exception() - _err_msg = f"{len(self._failed_tasks)=}, last failed: {_exc!r}" - if raise_last_exc: - raise RetryTaskMapInterrupted(_err_msg) from _exc - else: - logger.warning(_err_msg) - return self._failed_tasks.copy() - finally: - # be careful not to create ref cycle here - self._failed_tasks.clear() - _exc, self = None, None - - -class RetryTaskMap(Generic[T]): - def __init__( - self, - *, - backoff_func: Callable[[int], float], - max_retry: int, - max_concurrent: int, - max_workers: Optional[int] = None, - ) -> None: - self._running_inst: Optional[_TaskMap] = None - self._map_gen: Optional[Generator] = None - - self._backoff_func = backoff_func - self._retry_counter = range(max_retry) if max_retry else itertools.count() - self._max_concurrent = max_concurrent - self._max_workers = max_workers - self._executor = ThreadPoolExecutor(max_workers=self._max_workers) - - def map( - self, _func: Callable[[T], Any], _iter: Iterable[T] - ) -> Generator[DoneTask, None, None]: - retry_round = 0 - for retry_round in self._retry_counter: - self._running_inst = _inst = _TaskMap( - self._executor, self._max_concurrent, self._backoff_func - ) - logger.debug(f"{retry_round=} started") - - yield from _inst.map(_func, _iter) - - # this retry round ends, check overall result - if _failed_list := _inst.shutdown(raise_last_exc=False): - _iter = _failed_list # feed failed to next round - # deref before entering sleep - self._running_inst, _inst = None, None - - logger.warning(f"retry#{retry_round+1}: retry on {len(_failed_list)=}") - time.sleep(self._backoff_func(retry_round)) - else: # all tasks finished successfully - self._running_inst, _inst = None, None - return - try: - raise RetryTaskMapInterrupted(f"exceed try limit: {retry_round}") - finally: - # cleanup the defs - _func, _iter = None, None # type: ignore - - def shutdown(self, *, raise_last_exc: bool): - try: - logger.debug("shutdown retry task map") - if self._running_inst: - self._running_inst.shutdown(raise_last_exc=raise_last_exc) - # NOTE: passthrough the exception from underlying running_inst - finally: - self._running_inst = None - self._executor.shutdown(wait=True) - - -def create_tmp_fname(prefix="tmp", length=6, sep="_") -> str: - return f"{prefix}{sep}{os.urandom(length).hex()}" - - -def ensure_otaproxy_start( - otaproxy_url: str, - *, - interval: float = 1, - connection_timeout: float = 5, - probing_timeout: Optional[float] = None, - warning_interval: int = 3 * 60, # seconds -): - """Loop probing until online or exceed . - - This function will issue a logging.warning every seconds. - - Raises: - A ConnectionError if exceeds . - """ - start_time = int(time.time()) - next_warning = start_time + warning_interval - probing_timeout = ( - probing_timeout if probing_timeout and probing_timeout >= 0 else float("inf") - ) - with requests.Session() as session: - while start_time + probing_timeout > (cur_time := int(time.time())): - try: - resp = session.get(otaproxy_url, timeout=connection_timeout) - resp.close() - return - except Exception as e: # server is not up yet - if cur_time >= next_warning: - logger.warning( - f"otaproxy@{otaproxy_url} is not up after {cur_time - start_time} seconds" - f"it might be something wrong with this otaproxy: {e!r}" - ) - next_warning = next_warning + warning_interval - time.sleep(interval) - raise ConnectionError( - f"failed to ensure connection to {otaproxy_url} in {probing_timeout=}seconds" - ) - - -# -# ------ persist files handling ------ # -# -class PersistFilesHandler: - """Preserving files in persist list from to . - - Files being copied will have mode bits preserved, - and uid/gid preserved with mapping as follow: - - src_uid -> src_name -> dst_name -> dst_uid - src_gid -> src_name -> dst_name -> dst_gid - """ - - def __init__( - self, - src_passwd_file: str | Path, - src_group_file: str | Path, - dst_passwd_file: str | Path, - dst_group_file: str | Path, - *, - src_root: str | Path, - dst_root: str | Path, - ): - self._uid_mapper = lru_cache()( - partial( - self.map_uid_by_pwnam, - src_db=ParsedPasswd(src_passwd_file), - dst_db=ParsedPasswd(dst_passwd_file), - ) - ) - self._gid_mapper = lru_cache()( - partial( - self.map_gid_by_grpnam, - src_db=ParsedGroup(src_group_file), - dst_db=ParsedGroup(dst_group_file), - ) - ) - self._src_root = Path(src_root) - self._dst_root = Path(dst_root) - - @staticmethod - def map_uid_by_pwnam( - *, src_db: ParsedPasswd, dst_db: ParsedPasswd, uid: int - ) -> int: - _mapped_uid = map_uid_by_pwnam(src_db=src_db, dst_db=dst_db, uid=uid) - _usern = src_db._by_uid[uid] - - logger.info(f"{_usern=}: mapping src_{uid=} to {_mapped_uid=}") - return _mapped_uid - - @staticmethod - def map_gid_by_grpnam(*, src_db: ParsedGroup, dst_db: ParsedGroup, gid: int) -> int: - _mapped_gid = map_gid_by_grpnam(src_db=src_db, dst_db=dst_db, gid=gid) - _groupn = src_db._by_gid[gid] - - logger.info(f"{_groupn=}: mapping src_{gid=} to {_mapped_gid=}") - return _mapped_gid - - def _chown_with_mapping( - self, _src_stat: os.stat_result, _dst_path: str | Path - ) -> None: - _src_uid, _src_gid = _src_stat.st_uid, _src_stat.st_gid - try: - _dst_uid = self._uid_mapper(uid=_src_uid) - except ValueError: - logger.warning(f"failed to find mapping for {_src_uid=}, keep unchanged") - _dst_uid = _src_uid - - try: - _dst_gid = self._gid_mapper(gid=_src_gid) - except ValueError: - logger.warning(f"failed to find mapping for {_src_gid=}, keep unchanged") - _dst_gid = _src_gid - os.chown(_dst_path, uid=_dst_uid, gid=_dst_gid, follow_symlinks=False) - - @staticmethod - def _rm_target(_target: Path) -> None: - """Remove target with proper methods.""" - if _target.is_symlink() or _target.is_file(): - return _target.unlink(missing_ok=True) - elif _target.is_dir(): - return shutil.rmtree(_target, ignore_errors=True) - elif _target.exists(): - raise ValueError( - f"{_target} is not normal file/symlink/dir, failed to remove" - ) - - def _prepare_symlink(self, _src_path: Path, _dst_path: Path) -> None: - _dst_path.symlink_to(os.readlink(_src_path)) - # NOTE: to get stat from symlink, using os.stat with follow_symlinks=False - self._chown_with_mapping(os.stat(_src_path, follow_symlinks=False), _dst_path) - - def _prepare_dir(self, _src_path: Path, _dst_path: Path) -> None: - _dst_path.mkdir(exist_ok=True) - - _src_stat = os.stat(_src_path, follow_symlinks=False) - os.chmod(_dst_path, _src_stat.st_mode) - self._chown_with_mapping(_src_stat, _dst_path) - - def _prepare_file(self, _src_path: Path, _dst_path: Path) -> None: - shutil.copy(_src_path, _dst_path, follow_symlinks=False) - - _src_stat = os.stat(_src_path, follow_symlinks=False) - os.chmod(_dst_path, _src_stat.st_mode) - self._chown_with_mapping(_src_stat, _dst_path) - - def _prepare_parent(self, _origin_entry: Path) -> None: - for _parent in reversed(_origin_entry.parents): - _src_parent, _dst_parent = ( - self._src_root / _parent, - self._dst_root / _parent, - ) - if _dst_parent.is_dir(): # keep the origin parent on dst as it - continue - if _dst_parent.is_symlink() or _dst_parent.is_file(): - _dst_parent.unlink(missing_ok=True) - self._prepare_dir(_src_parent, _dst_parent) - continue - if _dst_parent.exists(): - raise ValueError( - f"{_dst_parent=} is not a normal file/symlink/dir, cannot cleanup" - ) - self._prepare_dir(_src_parent, _dst_parent) - - # API - - def preserve_persist_entry( - self, _persist_entry: str | Path, *, src_missing_ok: bool = True - ): - logger.info(f"preserving {_persist_entry}") - # persist_entry in persists.txt must be rooted at / - origin_entry = Path(_persist_entry).relative_to("/") - src_path = self._src_root / origin_entry - dst_path = self._dst_root / origin_entry - - # ------ src is symlink ------ # - # NOTE: always check if symlink first as is_file/is_dir/exists all follow_symlinks - if src_path.is_symlink(): - self._rm_target(dst_path) - self._prepare_parent(origin_entry) - self._prepare_symlink(src_path, dst_path) - return - - # ------ src is file ------ # - if src_path.is_file(): - self._rm_target(dst_path) - self._prepare_parent(origin_entry) - self._prepare_file(src_path, dst_path) - return - - # ------ src is not regular file/symlink/dir ------ # - # we only process normal file/symlink/dir - if src_path.exists() and not src_path.is_dir(): - raise ValueError(f"{src_path=} must be either a file/symlink/dir") - - # ------ src doesn't exist ------ # - if not src_path.exists(): - _err_msg = f"{src_path=} not found" - logger.warning(_err_msg) - if not src_missing_ok: - raise ValueError(_err_msg) - return - - # ------ src is dir ------ # - # dive into src_dir and preserve everything under the src dir - self._prepare_parent(origin_entry) - for src_curdir, dnames, fnames in os.walk(src_path, followlinks=False): - src_cur_dpath = Path(src_curdir) - dst_cur_dpath = self._dst_root / src_cur_dpath.relative_to(self._src_root) - - # ------ prepare current dir itself ------ # - self._rm_target(dst_cur_dpath) - self._prepare_dir(src_cur_dpath, dst_cur_dpath) - - # ------ prepare entries in current dir ------ # - for _fname in fnames: - _src_fpath, _dst_fpath = src_cur_dpath / _fname, dst_cur_dpath / _fname - self._rm_target(_dst_fpath) - if _src_fpath.is_symlink(): - self._prepare_symlink(_src_fpath, _dst_fpath) - continue - self._prepare_file(_src_fpath, _dst_fpath) - - # symlinks to dirs also included in dnames, we must handle it - for _dname in dnames: - _src_dpath, _dst_dpath = src_cur_dpath / _dname, dst_cur_dpath / _dname - if _src_dpath.is_symlink(): - self._rm_target(_dst_dpath) - self._prepare_symlink(_src_dpath, _dst_dpath) diff --git a/src/otaclient/app/configs.py b/src/otaclient/app/configs.py index f7a0ed725..9aa5da3f2 100644 --- a/src/otaclient/app/configs.py +++ b/src/otaclient/app/configs.py @@ -93,14 +93,10 @@ class BaseConfig(_InternalSettings): # ------ otaclient logging setting ------ # DEFAULT_LOG_LEVEL = INFO LOG_LEVEL_TABLE: Dict[str, int] = { - "otaclient.app.boot_control.cboot": INFO, - "otaclient.app.boot_control.grub": INFO, - "otaclient.app.ota_client": INFO, - "otaclient.app.ota_client_service": INFO, - "otaclient.app.ota_client_stub": INFO, - "otaclient.app.ota_metadata": INFO, - "otaclient.app.downloader": INFO, - "otaclient.app.main": INFO, + "ota_metadata": INFO, + "otaclient": INFO, + "otaclient_api": INFO, + "otaclient_common": INFO, } LOG_FORMAT = ( "[%(asctime)s][%(levelname)s]-%(name)s:%(funcName)s:%(lineno)d,%(message)s" diff --git a/src/otaclient/app/create_standby/common.py b/src/otaclient/app/create_standby/common.py index 6adb20c6e..eb3df40a2 100644 --- a/src/otaclient/app/create_standby/common.py +++ b/src/otaclient/app/create_standby/common.py @@ -30,10 +30,11 @@ from typing import Any, Dict, Iterator, List, Optional, OrderedDict, Set, Tuple, Union from weakref import WeakKeyDictionary, WeakValueDictionary -from ..common import create_tmp_fname +from ota_metadata.legacy.parser import MetafilesV1, OTAMetadata +from ota_metadata.legacy.types import DirectoryInf, RegularInf +from otaclient_common.common import create_tmp_fname + from ..configs import config as cfg -from ..ota_metadata import MetafilesV1, OTAMetadata -from ..proto.wrapper import DirectoryInf, RegularInf from ..update_stats import ( OTAUpdateStatsCollector, RegInfProcessedStats, diff --git a/src/otaclient/app/create_standby/interface.py b/src/otaclient/app/create_standby/interface.py index 41f5e5709..c66a6123e 100644 --- a/src/otaclient/app/create_standby/interface.py +++ b/src/otaclient/app/create_standby/interface.py @@ -33,7 +33,8 @@ from abc import abstractmethod from typing import Protocol -from ..ota_metadata import OTAMetadata +from ota_metadata.legacy.parser import OTAMetadata + from ..update_stats import OTAUpdateStatsCollector from .common import DeltaBundle diff --git a/src/otaclient/app/create_standby/rebuild_mode.py b/src/otaclient/app/create_standby/rebuild_mode.py index b8f7e0a24..26d7b6a2d 100644 --- a/src/otaclient/app/create_standby/rebuild_mode.py +++ b/src/otaclient/app/create_standby/rebuild_mode.py @@ -21,10 +21,12 @@ from pathlib import Path from typing import List, Set, Tuple -from ..common import RetryTaskMap, get_backoff +from ota_metadata.legacy.parser import MetafilesV1, OTAMetadata +from ota_metadata.legacy.types import RegularInf +from otaclient_common.common import get_backoff +from otaclient_common.retry_task_map import RetryTaskMap + from ..configs import config as cfg -from ..ota_metadata import MetafilesV1, OTAMetadata -from ..proto.wrapper import RegularInf from ..update_stats import ( OTAUpdateStatsCollector, RegInfProcessedStats, diff --git a/src/otaclient/app/errors.py b/src/otaclient/app/errors.py index 83408009f..6a673672e 100644 --- a/src/otaclient/app/errors.py +++ b/src/otaclient/app/errors.py @@ -18,7 +18,7 @@ from enum import Enum, unique from typing import ClassVar -from .proto import wrapper +from otaclient_api.v2 import types as api_types @unique @@ -70,7 +70,7 @@ class OTAError(Exception): ERROR_PREFIX: ClassVar[str] = "E" - failure_type: wrapper.FailureType = wrapper.FailureType.RECOVERABLE + failure_type: api_types.FailureType = api_types.FailureType.RECOVERABLE failure_errcode: OTAErrorCode = OTAErrorCode.E_UNSPECIFIC failure_description: str = "no description available for this error" @@ -118,7 +118,7 @@ def get_error_report(self, title: str = "") -> str: class NetworkError(OTAError): """Generic network error""" - failure_type: wrapper.FailureType = wrapper.FailureType.RECOVERABLE + failure_type: api_types.FailureType = api_types.FailureType.RECOVERABLE failure_errcode: OTAErrorCode = OTAErrorCode.E_NETWORK failure_description: str = _NETWORK_ERR_DEFAULT_DESC @@ -141,7 +141,7 @@ class OTAMetaDownloadFailed(NetworkError): class OTAErrorRecoverable(OTAError): - failure_type: wrapper.FailureType = wrapper.FailureType.RECOVERABLE + failure_type: api_types.FailureType = api_types.FailureType.RECOVERABLE failure_errcode: OTAErrorCode = OTAErrorCode.E_OTA_ERR_RECOVERABLE failure_description: str = _RECOVERABLE_DEFAULT_DESC @@ -168,7 +168,7 @@ class InvalidStatusForOTARollback(OTAErrorRecoverable): class OTAErrorUnrecoverable(OTAError): - failure_type: wrapper.FailureType = wrapper.FailureType.RECOVERABLE + failure_type: api_types.FailureType = api_types.FailureType.RECOVERABLE failure_errcode: OTAErrorCode = OTAErrorCode.E_OTA_ERR_UNRECOVERABLE failure_description: str = _UNRECOVERABLE_DEFAULT_DESC diff --git a/src/otaclient/app/interface.py b/src/otaclient/app/interface.py index cda4e0cfe..fe8bfc794 100644 --- a/src/otaclient/app/interface.py +++ b/src/otaclient/app/interface.py @@ -16,9 +16,10 @@ from abc import abstractmethod from typing import Protocol, Type +from otaclient_api.v2 import otaclient_v2_pb2 as pb2 + from .boot_control.protocol import BootControllerProtocol from .create_standby.interface import StandbySlotCreatorProtocol -from .proto import v2 class OTAClientProtocol(Protocol): @@ -44,4 +45,4 @@ def update( def rollback(self) -> None: ... @abstractmethod - def status(self) -> v2.StatusResponseEcu: ... + def status(self) -> pb2.StatusResponseEcu: ... diff --git a/src/otaclient/app/main.py b/src/otaclient/app/main.py index 63855d9c4..3d99d502e 100644 --- a/src/otaclient/app/main.py +++ b/src/otaclient/app/main.py @@ -13,20 +13,27 @@ # limitations under the License. +from __future__ import annotations + import asyncio import logging import os import sys from pathlib import Path -from otaclient import __version__ # type: ignore +import grpc.aio -from .common import read_str_from_file, write_str_to_file_sync -from .configs import config as cfg -from .configs import ecu_info -from .log_setting import configure_logging -from .ota_client_service import launch_otaclient_grpc_server -from .proto import ota_metafiles, v2, v2_grpc, wrapper # noqa: F401 +# NOTE: as otaclient_api and ota_metadata are using dynamic module import, +# we need to import them before any other otaclient modules. +import ota_metadata.legacy # noqa: F401 +from otaclient import __version__ +from otaclient.app.configs import config as cfg +from otaclient.app.configs import ecu_info, server_cfg +from otaclient.app.log_setting import configure_logging +from otaclient.app.ota_client_stub import OTAClientServiceStub +from otaclient_api.v2 import otaclient_v2_pb2_grpc as v2_grpc +from otaclient_api.v2.api_stub import OtaClientServiceV2 +from otaclient_common.common import read_str_from_file, write_str_to_file_sync # configure logging before any code being executed configure_logging() @@ -52,6 +59,24 @@ def _check_other_otaclient(): write_str_to_file_sync(cfg.OTACLIENT_PID_FILE, f"{os.getpid()}") +def create_otaclient_grpc_server(): + service_stub = OTAClientServiceStub() + ota_client_service_v2 = OtaClientServiceV2(service_stub) + + server = grpc.aio.server() + v2_grpc.add_OtaClientServiceServicer_to_server( + server=server, servicer=ota_client_service_v2 + ) + server.add_insecure_port(f"{ecu_info.ip_addr}:{server_cfg.SERVER_PORT}") + return server + + +async def launch_otaclient_grpc_server(): + server = create_otaclient_grpc_server() + await server.start() + await server.wait_for_termination() + + def main(): logger.info("started") logger.info(f"otaclient version: {__version__}") diff --git a/src/otaclient/app/ota_client.py b/src/otaclient/app/ota_client.py index fa7eed016..8ef78728b 100644 --- a/src/otaclient/app/ota_client.py +++ b/src/otaclient/app/ota_client.py @@ -29,37 +29,55 @@ from typing import Iterator, Optional, Type from urllib.parse import urlparse -from . import downloader +from ota_metadata.legacy import parser as ota_metadata_parser +from ota_metadata.legacy import types as ota_metadata_types +from otaclient import __version__ +from otaclient_api.v2 import types as api_types +from otaclient_common import downloader +from otaclient_common.common import ensure_otaproxy_start, get_backoff +from otaclient_common.persist_file_handling import PersistFilesHandler +from otaclient_common.retry_task_map import RetryTaskMap, RetryTaskMapInterrupted + from . import errors as ota_errors -from . import ota_metadata from .boot_control import BootControllerProtocol, get_boot_controller -from .common import ( - PersistFilesHandler, - RetryTaskMap, - RetryTaskMapInterrupted, - ensure_otaproxy_start, - get_backoff, -) from .configs import config as cfg from .configs import ecu_info from .create_standby import StandbySlotCreatorProtocol, get_standby_slot_creator from .interface import OTAClientProtocol -from .ota_status import LiveOTAStatus -from .proto import wrapper from .update_stats import ( OTAUpdateStatsCollector, RegInfProcessedStats, RegProcessOperation, ) -try: - from otaclient import __version__ # type: ignore -except ImportError: - __version__ = "unknown" - logger = logging.getLogger(__name__) +class LiveOTAStatus: + def __init__(self, ota_status: api_types.StatusOta) -> None: + self.live_ota_status = ota_status + + def get_ota_status(self) -> api_types.StatusOta: + return self.live_ota_status + + def set_ota_status(self, _status: api_types.StatusOta): + self.live_ota_status = _status + + def request_update(self) -> bool: + return self.live_ota_status in [ + api_types.StatusOta.INITIALIZED, + api_types.StatusOta.SUCCESS, + api_types.StatusOta.FAILURE, + api_types.StatusOta.ROLLBACK_FAILURE, + ] + + def request_rollback(self) -> bool: + return self.live_ota_status in [ + api_types.StatusOta.SUCCESS, + api_types.StatusOta.ROLLBACK_FAILURE, + ] + + class OTAClientControlFlags: """ When self ECU's otaproxy is enabled, all the child ECUs of this ECU @@ -100,12 +118,12 @@ def __init__( self._create_standby_cls = create_standby_cls # init update status - self.update_phase = wrapper.UpdatePhase.INITIALIZING + self.update_phase = api_types.UpdatePhase.INITIALIZING self.update_start_time = 0 self.updating_version: str = "" self.failure_reason = "" # init variables needed for update - self._otameta: ota_metadata.OTAMetadata = None # type: ignore + self._otameta: ota_metadata_parser.OTAMetadata = None # type: ignore self._url_base: str = None # type: ignore # dynamic update status @@ -132,10 +150,12 @@ def __init__( # helper methods - def _download_files(self, download_list: Iterator[wrapper.RegularInf]): + def _download_files(self, download_list: Iterator[ota_metadata_types.RegularInf]): """Download all needed OTA image files indicated by calculated bundle.""" - def _download_file(entry: wrapper.RegularInf) -> RegInfProcessedStats: + def _download_file( + entry: ota_metadata_types.RegularInf, + ) -> RegInfProcessedStats: """Download single OTA image file.""" cur_stat = RegInfProcessedStats(op=RegProcessOperation.DOWNLOAD_REMOTE_COPY) @@ -228,7 +248,7 @@ def _update_standby_slot(self): # --- init standby_slot creator, calculate delta --- # logger.info("start to calculate and prepare delta...") - self.update_phase = wrapper.UpdatePhase.CALCULATING_DELTA + self.update_phase = api_types.UpdatePhase.CALCULATING_DELTA self._standby_slot_creator = self._create_standby_cls( ota_metadata=self._otameta, boot_dir=str(self._boot_controller.get_standby_boot_dir()), @@ -254,7 +274,7 @@ def _update_standby_slot(self): "start to download needed files..." f"total_download_files_size={_delta_bundle.total_download_files_size:,}bytes" ) - self.update_phase = wrapper.UpdatePhase.DOWNLOADING_OTA_FILES + self.update_phase = api_types.UpdatePhase.DOWNLOADING_OTA_FILES try: self._download_files(_delta_bundle.get_download_list()) except downloader.DownloadFailedSpaceNotEnough: @@ -280,7 +300,7 @@ def _update_standby_slot(self): # ------ in_update ------ # logger.info("start to apply changes to standby slot...") - self.update_phase = wrapper.UpdatePhase.APPLYING_UPDATE + self.update_phase = api_types.UpdatePhase.APPLYING_UPDATE try: self._standby_slot_creator.create_standby_slot() except Exception as e: @@ -304,7 +324,7 @@ def _process_persistents(self): ) for _perinf in self._otameta.iter_metafile( - ota_metadata.MetafilesV1.PERSISTENT_FNAME + ota_metadata_parser.MetafilesV1.PERSISTENT_FNAME ): _per_fpath = Path(_perinf.path) @@ -353,7 +373,7 @@ def _execute_update( self._update_stats_collector.start() # ------ init, processing metadata ------ # - self.update_phase = wrapper.UpdatePhase.PROCESSING_METADATA + self.update_phase = api_types.UpdatePhase.PROCESSING_METADATA # parse url_base # unconditionally regulate the url_base _url_base = urlparse(raw_url_base) @@ -388,9 +408,12 @@ def _execute_update( # process metadata.jwt and ota metafiles logger.debug("process metadata.jwt...") try: - self._otameta = ota_metadata.OTAMetadata( + self._otameta = ota_metadata_parser.OTAMetadata( url_base=self._url_base, downloader=self._downloader, + run_dir=Path(cfg.RUN_DIR), + certs_dir=Path(cfg.CERTS_DIR), + download_max_idle_time=cfg.DOWNLOAD_GROUP_INACTIVE_TIMEOUT, ) self.total_files_num = self._otameta.total_files_num self.total_files_size_uncompressed = ( @@ -406,13 +429,13 @@ def _execute_update( _err_msg = f"downloader: failed to save ota metafiles: {e!r}" logger.error(_err_msg) raise ota_errors.OTAErrorUnrecoverable(_err_msg, module=__name__) from e - except ota_metadata.MetadataJWTVerificationFailed as e: + except ota_metadata_parser.MetadataJWTVerificationFailed as e: _err_msg = f"failed to verify metadata.jwt: {e!r}" logger.error(_err_msg) raise ota_errors.MetadataJWTVerficationFailed( _err_msg, module=__name__ ) from e - except ota_metadata.MetadataJWTPayloadInvalid as e: + except ota_metadata_parser.MetadataJWTPayloadInvalid as e: _err_msg = f"metadata.jwt is invalid: {e!r}" logger.error(_err_msg) raise ota_errors.MetadataJWTInvalid(_err_msg, module=__name__) from e @@ -434,7 +457,7 @@ def _execute_update( # ------ post update ------ # logger.info("enter post update phase...") - self.update_phase = wrapper.UpdatePhase.PROCESSING_POSTUPDATE + self.update_phase = api_types.UpdatePhase.PROCESSING_POSTUPDATE # NOTE(20240219): move persist file handling here self._process_persistents() @@ -448,11 +471,11 @@ def _execute_update( # API def shutdown(self): - self.update_phase = wrapper.UpdatePhase.INITIALIZING + self.update_phase = api_types.UpdatePhase.INITIALIZING self._downloader.shutdown() self._update_stats_collector.stop() - def get_update_status(self) -> wrapper.UpdateStatus: + def get_update_status(self) -> api_types.UpdateStatus: """ Returns: A tuple contains the version and the update_progress. @@ -472,13 +495,13 @@ def get_update_status(self) -> wrapper.UpdateStatus: update_progress.total_remove_files_num = self.total_remove_files_num # downloading stats update_progress.downloaded_bytes = self._downloader.downloaded_bytes - update_progress.downloading_elapsed_time = wrapper.Duration( + update_progress.downloading_elapsed_time = api_types.Duration( seconds=self._downloader.downloader_active_seconds ) # update other information update_progress.phase = self.update_phase - update_progress.total_elapsed_time = wrapper.Duration.from_nanoseconds( + update_progress.total_elapsed_time = api_types.Duration.from_nanoseconds( time.time_ns() - self.update_start_time ) return update_progress @@ -562,7 +585,7 @@ def __init__( self._rollback_executor: _OTARollbacker = None # type: ignore # err record - self.last_failure_type = wrapper.FailureType.NO_FAILURE + self.last_failure_type = api_types.FailureType.NO_FAILURE self.last_failure_reason = "" self.last_failure_traceback = "" except Exception as e: @@ -570,7 +593,7 @@ def __init__( logger.error(_err_msg) raise ota_errors.OTAClientStartupFailed(_err_msg, module=__name__) from e - def _on_failure(self, exc: ota_errors.OTAError, ota_status: wrapper.StatusOta): + def _on_failure(self, exc: ota_errors.OTAError, ota_status: api_types.StatusOta): self.live_ota_status.set_ota_status(ota_status) try: self.last_failure_type = exc.failure_type @@ -598,15 +621,15 @@ def update(self, version: str, url_base: str, cookies_json: str): ) # reset failure information on handling new update request - self.last_failure_type = wrapper.FailureType.NO_FAILURE + self.last_failure_type = api_types.FailureType.NO_FAILURE self.last_failure_reason = "" self.last_failure_traceback = "" # enter update - self.live_ota_status.set_ota_status(wrapper.StatusOta.UPDATING) + self.live_ota_status.set_ota_status(api_types.StatusOta.UPDATING) self._update_executor.execute(version, url_base, cookies_json) except ota_errors.OTAError as e: - self._on_failure(e, wrapper.StatusOta.FAILURE) + self._on_failure(e, api_types.StatusOta.FAILURE) finally: self._update_executor = None # type: ignore gc.collect() # trigger a forced gc @@ -625,16 +648,16 @@ def rollback(self): ) # clear failure information on handling new rollback request - self.last_failure_type = wrapper.FailureType.NO_FAILURE + self.last_failure_type = api_types.FailureType.NO_FAILURE self.last_failure_reason = "" self.last_failure_traceback = "" # entering rollback - self.live_ota_status.set_ota_status(wrapper.StatusOta.ROLLBACKING) + self.live_ota_status.set_ota_status(api_types.StatusOta.ROLLBACKING) self._rollback_executor.execute() # silently ignore overlapping request except ota_errors.OTAError as e: - self._on_failure(e, wrapper.StatusOta.ROLLBACK_FAILURE) + self._on_failure(e, api_types.StatusOta.ROLLBACK_FAILURE) finally: self._rollback_executor = None # type: ignore self._lock.release() @@ -643,9 +666,9 @@ def rollback(self): "ignore incoming rollback request as local update/rollback is ongoing" ) - def status(self) -> wrapper.StatusResponseEcuV2: + def status(self) -> api_types.StatusResponseEcuV2: live_ota_status = self.live_ota_status.get_ota_status() - status_report = wrapper.StatusResponseEcuV2( + status_report = api_types.StatusResponseEcuV2( ecu_id=self.my_ecu_id, firmware_version=self.current_version, otaclient_version=self.OTACLIENT_VERSION, @@ -654,7 +677,7 @@ def status(self) -> wrapper.StatusResponseEcuV2: failure_reason=self.last_failure_reason, failure_traceback=self.last_failure_traceback, ) - if live_ota_status == wrapper.StatusOta.UPDATING and self._update_executor: + if live_ota_status == api_types.StatusOta.UPDATING and self._update_executor: status_report.update_status = self._update_executor.get_update_status() return status_report @@ -671,15 +694,15 @@ def __init__( self.ecu_id = ecu_info.ecu_id self.otaclient_version = otaclient_version self.local_used_proxy_url = proxy - self.last_operation: Optional[wrapper.StatusOta] = None + self.last_operation: Optional[api_types.StatusOta] = None # default boot startup failure if boot_controller/otaclient_core crashed without # raising specific error - self._otaclient_startup_failed_status = wrapper.StatusResponseEcuV2( + self._otaclient_startup_failed_status = api_types.StatusResponseEcuV2( ecu_id=ecu_info.ecu_id, otaclient_version=otaclient_version, - ota_status=wrapper.StatusOta.FAILURE, - failure_type=wrapper.FailureType.UNRECOVERABLE, + ota_status=api_types.StatusOta.FAILURE, + failure_type=api_types.FailureType.UNRECOVERABLE, failure_reason="unspecific error", ) self._update_rollback_lock = asyncio.Lock() @@ -703,11 +726,11 @@ def __init__( logger.error( e.get_error_report(title=f"boot controller startup failed: {e!r}") ) - self._otaclient_startup_failed_status = wrapper.StatusResponseEcuV2( + self._otaclient_startup_failed_status = api_types.StatusResponseEcuV2( ecu_id=ecu_info.ecu_id, otaclient_version=otaclient_version, - ota_status=wrapper.StatusOta.FAILURE, - failure_type=wrapper.FailureType.UNRECOVERABLE, + ota_status=api_types.StatusOta.FAILURE, + failure_type=api_types.FailureType.UNRECOVERABLE, failure_reason=e.get_failure_reason(), ) @@ -730,11 +753,11 @@ def __init__( logger.error( e.get_error_report(title=f"otaclient core startup failed: {e!r}") ) - self._otaclient_startup_failed_status = wrapper.StatusResponseEcuV2( + self._otaclient_startup_failed_status = api_types.StatusResponseEcuV2( ecu_id=ecu_info.ecu_id, otaclient_version=otaclient_version, - ota_status=wrapper.StatusOta.FAILURE, - failure_type=wrapper.FailureType.UNRECOVERABLE, + ota_status=api_types.StatusOta.FAILURE, + failure_type=api_types.FailureType.UNRECOVERABLE, failure_reason=e.get_failure_reason(), ) @@ -749,12 +772,12 @@ def is_busy(self) -> bool: return self._update_rollback_lock.locked() async def dispatch_update( - self, request: wrapper.UpdateRequestEcu - ) -> wrapper.UpdateResponseEcu: + self, request: api_types.UpdateRequestEcu + ) -> api_types.UpdateResponseEcu: # prevent update operation if otaclient is not started if self._otaclient_inst is None: - return wrapper.UpdateResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.UNRECOVERABLE + return api_types.UpdateResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.UNRECOVERABLE ) # check and acquire lock @@ -762,13 +785,13 @@ async def dispatch_update( logger.warning( f"ongoing operation: {self.last_operation=}, ignore incoming {request=}" ) - return wrapper.UpdateResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.RECOVERABLE + return api_types.UpdateResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.RECOVERABLE ) # immediately take the lock if not locked await self._update_rollback_lock.acquire() - self.last_operation = wrapper.StatusOta.UPDATING + self.last_operation = api_types.StatusOta.UPDATING async def _update_task(): if self._otaclient_inst is None: @@ -790,17 +813,17 @@ async def _update_task(): # dispatch update to background asyncio.create_task(_update_task()) - return wrapper.UpdateResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.NO_FAILURE + return api_types.UpdateResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.NO_FAILURE ) async def dispatch_rollback( - self, request: wrapper.RollbackRequestEcu - ) -> wrapper.RollbackResponseEcu: + self, request: api_types.RollbackRequestEcu + ) -> api_types.RollbackResponseEcu: # prevent rollback operation if otaclient is not started if self._otaclient_inst is None: - return wrapper.RollbackResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.UNRECOVERABLE + return api_types.RollbackResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.UNRECOVERABLE ) # check and acquire lock @@ -808,13 +831,13 @@ async def dispatch_rollback( logger.warning( f"ongoing operation: {self.last_operation=}, ignore incoming {request=}" ) - return wrapper.RollbackResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.RECOVERABLE + return api_types.RollbackResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.RECOVERABLE ) # immediately take the lock if not locked await self._update_rollback_lock.acquire() - self.last_operation = wrapper.StatusOta.ROLLBACKING + self.last_operation = api_types.StatusOta.ROLLBACKING async def _rollback_task(): if self._otaclient_inst is None: @@ -829,11 +852,11 @@ async def _rollback_task(): # dispatch to background asyncio.create_task(_rollback_task()) - return wrapper.RollbackResponseEcu( - ecu_id=self.ecu_id, result=wrapper.FailureType.NO_FAILURE + return api_types.RollbackResponseEcu( + ecu_id=self.ecu_id, result=api_types.FailureType.NO_FAILURE ) - async def get_status(self) -> wrapper.StatusResponseEcuV2: + async def get_status(self) -> api_types.StatusResponseEcuV2: # otaclient is not started due to boot control startup failed if self._otaclient_inst is None: return self._otaclient_startup_failed_status diff --git a/src/otaclient/app/ota_client_service.py b/src/otaclient/app/ota_client_service.py deleted file mode 100644 index d9681817c..000000000 --- a/src/otaclient/app/ota_client_service.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import grpc.aio - -from .configs import ecu_info, server_cfg -from .ota_client_stub import OTAClientServiceStub -from .proto import v2, v2_grpc, wrapper - - -class OtaClientServiceV2(v2_grpc.OtaClientServiceServicer): - def __init__(self, ota_client_stub: OTAClientServiceStub): - self._stub = ota_client_stub - - async def Update(self, request: v2.UpdateRequest, context) -> v2.UpdateResponse: - response = await self._stub.update(wrapper.UpdateRequest.convert(request)) - return response.export_pb() - - async def Rollback( - self, request: v2.RollbackRequest, context - ) -> v2.RollbackResponse: - response = await self._stub.rollback(wrapper.RollbackRequest.convert(request)) - return response.export_pb() - - async def Status(self, request: v2.StatusRequest, context) -> v2.StatusResponse: - response = await self._stub.status(wrapper.StatusRequest.convert(request)) - return response.export_pb() - - -def create_otaclient_grpc_server(): - service_stub = OTAClientServiceStub() - ota_client_service_v2 = OtaClientServiceV2(service_stub) - - server = grpc.aio.server() - v2_grpc.add_OtaClientServiceServicer_to_server( - server=server, servicer=ota_client_service_v2 - ) - server.add_insecure_port(f"{ecu_info.ip_addr}:{server_cfg.SERVER_PORT}") - return server - - -async def launch_otaclient_grpc_server(): - server = create_otaclient_grpc_server() - await server.start() - await server.wait_for_termination() diff --git a/src/otaclient/app/ota_client_stub.py b/src/otaclient/app/ota_client_stub.py index e2cf6c3ef..08bcd0ca9 100644 --- a/src/otaclient/app/ota_client_stub.py +++ b/src/otaclient/app/ota_client_stub.py @@ -31,16 +31,16 @@ from ota_proxy import OTAProxyContextProto from ota_proxy import config as local_otaproxy_cfg from ota_proxy import subprocess_otaproxy_launcher +from otaclient.app import log_setting from otaclient.configs.ecu_info import ECUContact +from otaclient_api.v2 import types as api_types +from otaclient_api.v2.api_caller import ECUNoResponse, OTAClientCall +from otaclient_common.common import ensure_otaproxy_start -from . import log_setting from .boot_control._common import CMDHelperFuncs -from .common import ensure_otaproxy_start from .configs import config as cfg from .configs import ecu_info, proxy_info, server_cfg from .ota_client import OTAClientControlFlags, OTAServicer -from .ota_client_call import ECUNoResponse, OtaClientCall -from .proto import wrapper logger = logging.getLogger(__name__) @@ -348,8 +348,8 @@ def __init__(self) -> None: ecu_info.get_available_ecu_ids() ) - self._all_ecus_status_v2: Dict[str, wrapper.StatusResponseEcuV2] = {} - self._all_ecus_status_v1: Dict[str, wrapper.StatusResponseEcu] = {} + self._all_ecus_status_v2: Dict[str, api_types.StatusResponseEcuV2] = {} + self._all_ecus_status_v1: Dict[str, api_types.StatusResponseEcu] = {} self._all_ecus_last_contact_timestamp: Dict[str, int] = {} # overall ECU status report @@ -511,7 +511,7 @@ async def _loop_updating_properties(self): # API - async def update_from_child_ecu(self, status_resp: wrapper.StatusResponse): + async def update_from_child_ecu(self, status_resp: api_types.StatusResponse): """Update the ECU status storage with child ECU's status report(StatusResponse).""" async with self._writer_lock: self.storage_last_updated_timestamp = cur_timestamp = int(time.time()) @@ -537,7 +537,7 @@ async def update_from_child_ecu(self, status_resp: wrapper.StatusResponse): self._all_ecus_last_contact_timestamp[ecu_id] = cur_timestamp self._all_ecus_status_v2.pop(ecu_id, None) - async def update_from_local_ecu(self, ecu_status: wrapper.StatusResponseEcuV2): + async def update_from_local_ecu(self, ecu_status: api_types.StatusResponseEcuV2): """Update ECU status storage with local ECU's status report(StatusResponseEcuV2).""" async with self._writer_lock: self.storage_last_updated_timestamp = cur_timestamp = int(time.time()) @@ -615,7 +615,7 @@ async def _waiter(): return _waiter - async def export(self) -> wrapper.StatusResponse: + async def export(self) -> api_types.StatusResponse: """Export the contents of this storage to an instance of StatusResponse. NOTE: wrapper.StatusResponse's add_ecu method already takes care of @@ -625,7 +625,7 @@ async def export(self) -> wrapper.StatusResponse: entry in status API response, simulate this behavior by skipping disconnected ECU's status report entry. """ - res = wrapper.StatusResponse() + res = api_types.StatusResponse() async with self._writer_lock: res.available_ecu_ids.extend(self._available_ecu_ids) @@ -678,12 +678,12 @@ async def _polling_direct_subecu_status(self, ecu_contact: ECUContact): """Task entry for loop polling one subECU's status.""" while not self._debug_ecu_status_polling_shutdown_event.is_set(): try: - _ecu_resp = await OtaClientCall.status_call( + _ecu_resp = await OTAClientCall.status_call( ecu_contact.ecu_id, str(ecu_contact.ip_addr), ecu_contact.port, timeout=server_cfg.QUERYING_SUBECU_STATUS_TIMEOUT, - request=wrapper.StatusRequest(), + request=api_types.StatusRequest(), ) await self._ecu_status_storage.update_from_child_ecu(_ecu_resp) except ECUNoResponse as e: @@ -810,10 +810,12 @@ async def _otaclient_control_flags_managing(self): # API stub - async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse: + async def update( + self, request: api_types.UpdateRequest + ) -> api_types.UpdateResponse: logger.info(f"receive update request: {request}") update_acked_ecus = set() - response = wrapper.UpdateResponse() + response = api_types.UpdateResponse() # first: dispatch update request to all directly connected subECUs tasks: Dict[asyncio.Task, ECUContact] = {} @@ -821,7 +823,7 @@ async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse if not request.if_contains_ecu(ecu_contact.ecu_id): continue _task = asyncio.create_task( - OtaClientCall.update_call( + OTAClientCall.update_call( ecu_contact.ecu_id, str(ecu_contact.ip_addr), ecu_contact.port, @@ -834,7 +836,7 @@ async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse done, _ = await asyncio.wait(tasks) for _task in done: try: - _ecu_resp: wrapper.UpdateResponse = _task.result() + _ecu_resp: api_types.UpdateResponse = _task.result() update_acked_ecus.update(_ecu_resp.ecus_acked_update) response.merge_from(_ecu_resp) except ECUNoResponse as e: @@ -847,9 +849,9 @@ async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse # response with RECOVERABLE OTA error for unresponsive # ECU. response.add_ecu( - wrapper.UpdateResponseEcu( + api_types.UpdateResponseEcu( ecu_id=_ecu_contact.ecu_id, - result=wrapper.FailureType.RECOVERABLE, + result=api_types.FailureType.RECOVERABLE, ) ) tasks.clear() @@ -858,7 +860,7 @@ async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse if update_req_ecu := request.find_ecu(self.my_ecu_id): _resp_ecu = await self._otaclient_wrapper.dispatch_update(update_req_ecu) # local otaclient accepts the update request - if _resp_ecu.result == wrapper.FailureType.NO_FAILURE: + if _resp_ecu.result == api_types.FailureType.NO_FAILURE: update_acked_ecus.add(self.my_ecu_id) response.add_ecu(_resp_ecu) @@ -874,10 +876,10 @@ async def update(self, request: wrapper.UpdateRequest) -> wrapper.UpdateResponse return response async def rollback( - self, request: wrapper.RollbackRequest - ) -> wrapper.RollbackResponse: + self, request: api_types.RollbackRequest + ) -> api_types.RollbackResponse: logger.info(f"receive rollback request: {request}") - response = wrapper.RollbackResponse() + response = api_types.RollbackResponse() # first: dispatch rollback request to all directly connected subECUs tasks: Dict[asyncio.Task, ECUContact] = {} @@ -885,7 +887,7 @@ async def rollback( if not request.if_contains_ecu(ecu_contact.ecu_id): continue _task = asyncio.create_task( - OtaClientCall.rollback_call( + OTAClientCall.rollback_call( ecu_contact.ecu_id, str(ecu_contact.ip_addr), ecu_contact.port, @@ -898,7 +900,7 @@ async def rollback( done, _ = await asyncio.wait(tasks) for _task in done: try: - _ecu_resp: wrapper.RollbackResponse = _task.result() + _ecu_resp: api_types.RollbackResponse = _task.result() response.merge_from(_ecu_resp) except ECUNoResponse as e: _ecu_contact = tasks[_task] @@ -910,9 +912,9 @@ async def rollback( # response with RECOVERABLE OTA error for unresponsive # ECU. response.add_ecu( - wrapper.RollbackResponseEcu( + api_types.RollbackResponseEcu( ecu_id=_ecu_contact.ecu_id, - result=wrapper.FailureType.RECOVERABLE, + result=api_types.FailureType.RECOVERABLE, ) ) tasks.clear() @@ -925,5 +927,5 @@ async def rollback( return response - async def status(self, _=None) -> wrapper.StatusResponse: + async def status(self, _=None) -> api_types.StatusResponse: return await self._ecu_status_storage.export() diff --git a/src/otaclient/app/ota_status.py b/src/otaclient/app/ota_status.py deleted file mode 100644 index 53ff85c8f..000000000 --- a/src/otaclient/app/ota_status.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from .proto import wrapper - -logger = logging.getLogger(__name__) - - -class LiveOTAStatus: - def __init__(self, ota_status: wrapper.StatusOta) -> None: - self.live_ota_status = ota_status - - def get_ota_status(self) -> wrapper.StatusOta: - return self.live_ota_status - - def set_ota_status(self, _status: wrapper.StatusOta): - self.live_ota_status = _status - - def request_update(self) -> bool: - return self.live_ota_status in [ - wrapper.StatusOta.INITIALIZED, - wrapper.StatusOta.SUCCESS, - wrapper.StatusOta.FAILURE, - wrapper.StatusOta.ROLLBACK_FAILURE, - ] - - def request_rollback(self) -> bool: - return self.live_ota_status in [ - wrapper.StatusOta.SUCCESS, - wrapper.StatusOta.ROLLBACK_FAILURE, - ] diff --git a/src/otaclient/app/proto/__init__.py b/src/otaclient/app/proto/__init__.py deleted file mode 100644 index 82b5169a5..000000000 --- a/src/otaclient/app/proto/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Packed compiled protobuf files for otaclient.""" -import importlib.util -import sys -from pathlib import Path -from types import ModuleType -from typing import Tuple - -_PROTO_DIR = Path(__file__).parent -# NOTE: order matters here! v2_pb2_grpc depends on v2_pb2 -_FILES_TO_LOAD = [ - _PROTO_DIR / _fname - for _fname in [ - "otaclient_v2_pb2.py", - "otaclient_v2_pb2_grpc.py", - "ota_metafiles_pb2.py", - ] -] - - -def _import_from_file(path: Path) -> Tuple[str, ModuleType]: - if not path.is_file(): - raise ValueError(f"{path} is not a valid module file") - try: - _module_name = path.stem - _spec = importlib.util.spec_from_file_location(_module_name, path) - _module = importlib.util.module_from_spec(_spec) # type: ignore - _spec.loader.exec_module(_module) # type: ignore - return _module_name, _module - except Exception: - raise ImportError(f"failed to import module from {path=}.") - - -def _import_proto(*module_fpaths: Path): - """Import the protobuf modules to path under this folder. - - NOTE: compiled protobuf files under proto folder will be - imported as modules to the global namespace. - """ - for _fpath in module_fpaths: - _module_name, _module = _import_from_file(_fpath) # noqa: F821 - # add the module to the global module namespace - sys.modules[_module_name] = _module - - -_import_proto(*_FILES_TO_LOAD) -del _import_proto, _import_from_file - -import ota_metafiles_pb2 as ota_metafiles # noqa: E402 -import otaclient_v2_pb2 as v2 # noqa: E402 -import otaclient_v2_pb2_grpc as v2_grpc # noqa: E402 - -from . import streamer # noqa: E402 -from . import wrapper # noqa: E402 - -__all__ = ["v2", "v2_grpc", "ota_metafiles", "wrapper", "streamer"] diff --git a/src/otaclient/app/proto/wrapper.py b/src/otaclient/app/proto/wrapper.py deleted file mode 100644 index d7e0d8eb7..000000000 --- a/src/otaclient/app/proto/wrapper.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Modules for registering wrapped compiled protobuf types.""" - - -from ._common import * # noqa: F403, F401 -from ._ota_metafiles_wrapper import * # noqa: F403, F401 -from ._otaclient_v2_pb2_wrapper import * # noqa: F403, F401 diff --git a/src/otaclient/app/update_stats.py b/src/otaclient/app/update_stats.py index c01ee6cda..c4b68fecb 100644 --- a/src/otaclient/app/update_stats.py +++ b/src/otaclient/app/update_stats.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import logging import time from contextlib import contextmanager @@ -20,10 +22,11 @@ from enum import Enum from queue import Empty, Queue from threading import Event, Lock, Thread -from typing import Generator, List +from typing import Generator + +from otaclient_api.v2.types import UpdateStatus from .configs import config as cfg -from .proto.wrapper import UpdateStatus logger = logging.getLogger(__name__) @@ -60,7 +63,7 @@ def __init__(self) -> None: self.collect_interval = cfg.STATS_COLLECT_INTERVAL self.terminated = Event() self._que: Queue[RegInfProcessedStats] = Queue() - self._staging: List[RegInfProcessedStats] = [] + self._staging: list[RegInfProcessedStats] = [] self._collector_thread = None @contextmanager @@ -111,7 +114,7 @@ def get_snapshot(self) -> UpdateStatus: report_download_ota_files = _report report_prepare_local_copy = _report - def report_apply_delta(self, stats_list: List[RegInfProcessedStats]): + def report_apply_delta(self, stats_list: list[RegInfProcessedStats]): """Stats report for APPLY_DELTA operation. Params: diff --git a/src/otaclient/configs/ecu_info.py b/src/otaclient/configs/ecu_info.py index f6fa332bc..08b850bb7 100644 --- a/src/otaclient/configs/ecu_info.py +++ b/src/otaclient/configs/ecu_info.py @@ -26,8 +26,8 @@ from pydantic import AfterValidator, BeforeValidator, Field, IPvAnyAddress from typing_extensions import Annotated -from otaclient._utils.typing import NetworkPort, StrOrPath, gen_strenum_validator from otaclient.configs._common import BaseFixedConfig +from otaclient_common.typing import NetworkPort, StrOrPath, gen_strenum_validator logger = logging.getLogger(__name__) diff --git a/src/otaclient/configs/proxy_info.py b/src/otaclient/configs/proxy_info.py index 3ed4d4c3e..db16a1844 100644 --- a/src/otaclient/configs/proxy_info.py +++ b/src/otaclient/configs/proxy_info.py @@ -26,8 +26,8 @@ from pydantic import AliasChoices, AnyHttpUrl, Field, IPvAnyAddress from pydantic_core import Url -from otaclient._utils.typing import NetworkPort, StrOrPath from otaclient.configs._common import BaseFixedConfig +from otaclient_common.typing import NetworkPort, StrOrPath logger = logging.getLogger(__name__) diff --git a/src/otaclient_api/v2/README.md b/src/otaclient_api/v2/README.md new file mode 100644 index 000000000..7fb5b1cfd --- /dev/null +++ b/src/otaclient_api/v2/README.md @@ -0,0 +1,3 @@ +# OTAClient API version 2 + +Package for holding the protobuf python pb2 generated files and related wrappers and libs for OTAClient API version 2. diff --git a/src/otaclient_api/v2/__init__.py b/src/otaclient_api/v2/__init__.py new file mode 100644 index 000000000..ff3cd5252 --- /dev/null +++ b/src/otaclient_api/v2/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OTAClient API, version 2.""" + + +from __future__ import annotations + +import sys +from pathlib import Path + +from otaclient_common import import_from_file + +# ------ dynamically import pb2 generated code ------ # + +_PROTO_DIR = Path(__file__).parent +# NOTE: order matters here! pb2_grpc depends on pb2 +_FILES_TO_LOAD = [ + _PROTO_DIR / "otaclient_v2_pb2.py", + _PROTO_DIR / "otaclient_v2_pb2_grpc.py", +] +PACKAGE_PREFIX = ".".join(__name__.split(".")[:-1]) + + +def _import_pb2_proto(*module_fpaths: Path): + """Import the protobuf modules to path under this folder. + + NOTE: compiled protobuf files under proto folder will be + imported as modules to the global namespace. + """ + for _fpath in module_fpaths: + _module_name, _module = import_from_file(_fpath) + sys.modules[f"{PACKAGE_PREFIX}.{_module_name}"] = _module + sys.modules[_module_name] = _module + + +_import_pb2_proto(*_FILES_TO_LOAD) +del _import_pb2_proto diff --git a/src/otaclient/app/ota_client_call.py b/src/otaclient_api/v2/api_caller.py similarity index 72% rename from src/otaclient/app/ota_client_call.py rename to src/otaclient_api/v2/api_caller.py index a655232e4..72bc6d38f 100644 --- a/src/otaclient/app/ota_client_call.py +++ b/src/otaclient_api/v2/api_caller.py @@ -11,34 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""OTAClient API caller implementation.""" +from __future__ import annotations + import grpc.aio -from .configs import server_cfg -from .proto import v2_grpc, wrapper +from otaclient_api.v2 import otaclient_v2_pb2_grpc as pb2_grpc +from otaclient_api.v2 import types class ECUNoResponse(Exception): """Raised when ECU cannot response to request on-time.""" -class OtaClientCall: +class OTAClientCall: @staticmethod async def status_call( ecu_id: str, ecu_ipaddr: str, - ecu_port: int = server_cfg.SERVER_PORT, + ecu_port: int, *, - request: wrapper.StatusRequest, + request: types.StatusRequest, timeout=None, - ) -> wrapper.StatusResponse: + ) -> types.StatusResponse: try: ecu_addr = f"{ecu_ipaddr}:{ecu_port}" async with grpc.aio.insecure_channel(ecu_addr) as channel: - stub = v2_grpc.OtaClientServiceStub(channel) + stub = pb2_grpc.OtaClientServiceStub(channel) resp = await stub.Status(request.export_pb(), timeout=timeout) - return wrapper.StatusResponse.convert(resp) + return types.StatusResponse.convert(resp) except Exception as e: _msg = f"{ecu_id=} failed to respond to status request on-time: {e!r}" raise ECUNoResponse(_msg) @@ -47,17 +50,17 @@ async def status_call( async def update_call( ecu_id: str, ecu_ipaddr: str, - ecu_port: int = server_cfg.SERVER_PORT, + ecu_port: int, *, - request: wrapper.UpdateRequest, + request: types.UpdateRequest, timeout=None, - ) -> wrapper.UpdateResponse: + ) -> types.UpdateResponse: try: ecu_addr = f"{ecu_ipaddr}:{ecu_port}" async with grpc.aio.insecure_channel(ecu_addr) as channel: - stub = v2_grpc.OtaClientServiceStub(channel) + stub = pb2_grpc.OtaClientServiceStub(channel) resp = await stub.Update(request.export_pb(), timeout=timeout) - return wrapper.UpdateResponse.convert(resp) + return types.UpdateResponse.convert(resp) except Exception as e: _msg = f"{ecu_id=} failed to respond to update request on-time: {e!r}" raise ECUNoResponse(_msg) @@ -66,17 +69,17 @@ async def update_call( async def rollback_call( ecu_id: str, ecu_ipaddr: str, - ecu_port: int = server_cfg.SERVER_PORT, + ecu_port: int, *, - request: wrapper.RollbackRequest, + request: types.RollbackRequest, timeout=None, - ) -> wrapper.RollbackResponse: + ) -> types.RollbackResponse: try: ecu_addr = f"{ecu_ipaddr}:{ecu_port}" async with grpc.aio.insecure_channel(ecu_addr) as channel: - stub = v2_grpc.OtaClientServiceStub(channel) + stub = pb2_grpc.OtaClientServiceStub(channel) resp = await stub.Rollback(request.export_pb(), timeout=timeout) - return wrapper.RollbackResponse.convert(resp) + return types.RollbackResponse.convert(resp) except Exception as e: _msg = f"{ecu_id=} failed to respond to rollback request on-time: {e!r}" raise ECUNoResponse(_msg) diff --git a/src/otaclient_api/v2/api_stub.py b/src/otaclient_api/v2/api_stub.py new file mode 100644 index 000000000..167eb8f49 --- /dev/null +++ b/src/otaclient_api/v2/api_stub.py @@ -0,0 +1,41 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import Any + +from otaclient_api.v2 import otaclient_v2_pb2 as pb2 +from otaclient_api.v2 import otaclient_v2_pb2_grpc as pb2_grpc +from otaclient_api.v2 import types + + +class OtaClientServiceV2(pb2_grpc.OtaClientServiceServicer): + def __init__(self, ota_client_stub: Any): + self._stub = ota_client_stub + + async def Update(self, request: pb2.UpdateRequest, context) -> pb2.UpdateResponse: + response = await self._stub.update(types.UpdateRequest.convert(request)) + return response.export_pb() + + async def Rollback( + self, request: pb2.RollbackRequest, context + ) -> pb2.RollbackResponse: + response = await self._stub.rollback(types.RollbackRequest.convert(request)) + return response.export_pb() + + async def Status(self, request: pb2.StatusRequest, context) -> pb2.StatusResponse: + response = await self._stub.status(types.StatusRequest.convert(request)) + return response.export_pb() diff --git a/src/otaclient/app/proto/otaclient_v2_pb2.py b/src/otaclient_api/v2/otaclient_v2_pb2.py similarity index 100% rename from src/otaclient/app/proto/otaclient_v2_pb2.py rename to src/otaclient_api/v2/otaclient_v2_pb2.py diff --git a/src/otaclient/app/proto/otaclient_v2_pb2.pyi b/src/otaclient_api/v2/otaclient_v2_pb2.pyi similarity index 100% rename from src/otaclient/app/proto/otaclient_v2_pb2.pyi rename to src/otaclient_api/v2/otaclient_v2_pb2.pyi diff --git a/src/otaclient/app/proto/otaclient_v2_pb2_grpc.py b/src/otaclient_api/v2/otaclient_v2_pb2_grpc.py similarity index 100% rename from src/otaclient/app/proto/otaclient_v2_pb2_grpc.py rename to src/otaclient_api/v2/otaclient_v2_pb2_grpc.py diff --git a/src/otaclient/app/proto/_otaclient_v2_pb2_wrapper.py b/src/otaclient_api/v2/types.py similarity index 84% rename from src/otaclient/app/proto/_otaclient_v2_pb2_wrapper.py rename to src/otaclient_api/v2/types.py index cc6409d08..f784a568c 100644 --- a/src/otaclient/app/proto/_otaclient_v2_pb2_wrapper.py +++ b/src/otaclient_api/v2/types.py @@ -30,10 +30,10 @@ from typing import TypeVar as _TypeVar from typing import Union as _Union -import otaclient_v2_pb2 as _v2 from typing_extensions import Self -from ._common import ( +from otaclient_api.v2 import otaclient_v2_pb2 as pb2 +from otaclient_common.proto_wrapper import ( Duration, EnumWrapper, MessageWrapper, @@ -123,41 +123,41 @@ def requires_network(self) -> bool: class FailureType(EnumWrapper): - NO_FAILURE = _v2.NO_FAILURE - RECOVERABLE = _v2.RECOVERABLE - UNRECOVERABLE = _v2.UNRECOVERABLE + NO_FAILURE = pb2.NO_FAILURE + RECOVERABLE = pb2.RECOVERABLE + UNRECOVERABLE = pb2.UNRECOVERABLE def to_str(self) -> str: return f"{self.value:0>1}" class StatusOta(EnumWrapper): - INITIALIZED = _v2.INITIALIZED - SUCCESS = _v2.SUCCESS - FAILURE = _v2.FAILURE - UPDATING = _v2.UPDATING - ROLLBACKING = _v2.ROLLBACKING - ROLLBACK_FAILURE = _v2.ROLLBACK_FAILURE + INITIALIZED = pb2.INITIALIZED + SUCCESS = pb2.SUCCESS + FAILURE = pb2.FAILURE + UPDATING = pb2.UPDATING + ROLLBACKING = pb2.ROLLBACKING + ROLLBACK_FAILURE = pb2.ROLLBACK_FAILURE class StatusProgressPhase(EnumWrapper): - INITIAL = _v2.INITIAL - METADATA = _v2.METADATA - DIRECTORY = _v2.DIRECTORY - SYMLINK = _v2.SYMLINK - REGULAR = _v2.REGULAR - PERSISTENT = _v2.PERSISTENT - POST_PROCESSING = _v2.POST_PROCESSING + INITIAL = pb2.INITIAL + METADATA = pb2.METADATA + DIRECTORY = pb2.DIRECTORY + SYMLINK = pb2.SYMLINK + REGULAR = pb2.REGULAR + PERSISTENT = pb2.PERSISTENT + POST_PROCESSING = pb2.POST_PROCESSING class UpdatePhase(EnumWrapper): - INITIALIZING = _v2.INITIALIZING - PROCESSING_METADATA = _v2.PROCESSING_METADATA - CALCULATING_DELTA = _v2.CALCULATING_DELTA - DOWNLOADING_OTA_FILES = _v2.DOWNLOADING_OTA_FILES - APPLYING_UPDATE = _v2.APPLYING_UPDATE - PROCESSING_POSTUPDATE = _v2.PROCESSING_POSTUPDATE - FINALIZING_UPDATE = _v2.FINALIZING_UPDATE + INITIALIZING = pb2.INITIALIZING + PROCESSING_METADATA = pb2.PROCESSING_METADATA + CALCULATING_DELTA = pb2.CALCULATING_DELTA + DOWNLOADING_OTA_FILES = pb2.DOWNLOADING_OTA_FILES + APPLYING_UPDATE = pb2.APPLYING_UPDATE + PROCESSING_POSTUPDATE = pb2.PROCESSING_POSTUPDATE + FINALIZING_UPDATE = pb2.FINALIZING_UPDATE # message wrapper definitions @@ -166,15 +166,15 @@ class UpdatePhase(EnumWrapper): # rollback API -class RollbackRequestEcu(MessageWrapper[_v2.RollbackRequestEcu]): - __slots__ = calculate_slots(_v2.RollbackRequestEcu) +class RollbackRequestEcu(MessageWrapper[pb2.RollbackRequestEcu]): + __slots__ = calculate_slots(pb2.RollbackRequestEcu) ecu_id: str def __init__(self, *, ecu_id: _Optional[str] = ...) -> None: ... -class RollbackRequest(ECUList[RollbackRequestEcu], MessageWrapper[_v2.RollbackRequest]): - __slots__ = calculate_slots(_v2.RollbackRequest) +class RollbackRequest(ECUList[RollbackRequestEcu], MessageWrapper[pb2.RollbackRequest]): + __slots__ = calculate_slots(pb2.RollbackRequest) ecu: RepeatedCompositeContainer[RollbackRequestEcu] def __init__( @@ -182,8 +182,8 @@ def __init__( ) -> None: ... -class RollbackResponseEcu(MessageWrapper[_v2.RollbackResponseEcu]): - __slots__ = calculate_slots(_v2.RollbackRequestEcu) +class RollbackResponseEcu(MessageWrapper[pb2.RollbackResponseEcu]): + __slots__ = calculate_slots(pb2.RollbackRequestEcu) ecu_id: str result: FailureType @@ -196,17 +196,17 @@ def __init__( class RollbackResponse( - ECUList[RollbackResponseEcu], MessageWrapper[_v2.RollbackResponse] + ECUList[RollbackResponseEcu], MessageWrapper[pb2.RollbackResponse] ): - __slots__ = calculate_slots(_v2.RollbackResponse) + __slots__ = calculate_slots(pb2.RollbackResponse) ecu: RepeatedCompositeContainer[RollbackResponseEcu] def __init__( self, *, ecu: _Optional[_Iterable[RollbackResponseEcu]] = ... ) -> None: ... - def merge_from(self, rollback_response: _Union[Self, _v2.RollbackResponse]): - if isinstance(rollback_response, _v2.RollbackResponse): + def merge_from(self, rollback_response: _Union[Self, pb2.RollbackResponse]): + if isinstance(rollback_response, pb2.RollbackResponse): rollback_response = self.__class__.convert(rollback_response) # NOTE, TODO: duplication check is not done self.ecu.extend(rollback_response.ecu) @@ -215,8 +215,8 @@ def merge_from(self, rollback_response: _Union[Self, _v2.RollbackResponse]): # status API -class StatusProgress(MessageWrapper[_v2.StatusProgress]): - __slots__ = calculate_slots(_v2.StatusProgress) +class StatusProgress(MessageWrapper[pb2.StatusProgress]): + __slots__ = calculate_slots(pb2.StatusProgress) download_bytes: int elapsed_time_copy: Duration elapsed_time_download: Duration @@ -263,8 +263,8 @@ def add_elapsed_time(self, _field_name: str, _value: int): _field.add_nanoseconds(_value) -class Status(MessageWrapper[_v2.Status]): - __slots__ = calculate_slots(_v2.Status) +class Status(MessageWrapper[pb2.Status]): + __slots__ = calculate_slots(pb2.Status) failure: FailureType failure_reason: str progress: StatusProgress @@ -282,12 +282,12 @@ def __init__( ) -> None: ... -class StatusRequest(MessageWrapper[_v2.StatusRequest]): - __slots__ = calculate_slots(_v2.StatusRequest) +class StatusRequest(MessageWrapper[pb2.StatusRequest]): + __slots__ = calculate_slots(pb2.StatusRequest) -class StatusResponseEcu(ECUStatusSummary, MessageWrapper[_v2.StatusResponseEcu]): - __slots__ = calculate_slots(_v2.StatusResponseEcu) +class StatusResponseEcu(ECUStatusSummary, MessageWrapper[pb2.StatusResponseEcu]): + __slots__ = calculate_slots(pb2.StatusResponseEcu) ecu_id: str result: FailureType status: Status @@ -334,8 +334,8 @@ def requires_network(self) -> bool: } -class UpdateStatus(MessageWrapper[_v2.UpdateStatus]): - __slots__ = calculate_slots(_v2.UpdateStatus) +class UpdateStatus(MessageWrapper[pb2.UpdateStatus]): + __slots__ = calculate_slots(pb2.UpdateStatus) delta_generating_elapsed_time: Duration downloaded_bytes: int downloaded_files_num: int @@ -422,8 +422,8 @@ def convert_to_v1_StatusProgress(self) -> StatusProgress: return _res -class StatusResponseEcuV2(ECUStatusSummary, MessageWrapper[_v2.StatusResponseEcuV2]): - __slots__ = calculate_slots(_v2.StatusResponseEcuV2) +class StatusResponseEcuV2(ECUStatusSummary, MessageWrapper[pb2.StatusResponseEcuV2]): + __slots__ = calculate_slots(pb2.StatusResponseEcuV2) ecu_id: str failure_reason: str failure_traceback: str @@ -483,9 +483,9 @@ def requires_network(self) -> bool: class StatusResponse( ECUV2List[StatusResponseEcuV2], ECUList[StatusResponseEcu], - MessageWrapper[_v2.StatusResponse], + MessageWrapper[pb2.StatusResponse], ): - __slots__ = calculate_slots(_v2.StatusResponse) + __slots__ = calculate_slots(pb2.StatusResponse) available_ecu_ids: RepeatedScalarContainer[str] ecu: RepeatedCompositeContainer[StatusResponseEcu] ecu_v2: RepeatedCompositeContainer[StatusResponseEcuV2] @@ -502,20 +502,20 @@ def add_ecu(self, _response_ecu: Any): if isinstance(_response_ecu, StatusResponseEcuV2): self.ecu_v2.append(_response_ecu) self.ecu.append(_response_ecu.convert_to_v1()) # v1 compat - elif isinstance(_response_ecu, _v2.StatusResponseEcuV2): + elif isinstance(_response_ecu, pb2.StatusResponseEcuV2): _converted = StatusResponseEcuV2.convert(_response_ecu) self.ecu_v2.append(_response_ecu) self.ecu.append(_converted.convert_to_v1()) # v1 compat # v1 elif isinstance(_response_ecu, StatusResponseEcu): self.ecu.append(_response_ecu) - elif isinstance(_response_ecu, _v2.StatusResponseEcu): + elif isinstance(_response_ecu, pb2.StatusResponseEcu): self.ecu.append(StatusResponseEcu.convert(_response_ecu)) else: raise TypeError - def merge_from(self, status_resp: _Union[Self, _v2.StatusResponse]): - if isinstance(status_resp, _v2.StatusResponse): + def merge_from(self, status_resp: _Union[Self, pb2.StatusResponse]): + if isinstance(status_resp, pb2.StatusResponse): status_resp = self.__class__.convert(status_resp) # merge ecu only, don't merge available_ecu_ids! # NOTE, TODO: duplication check is not done @@ -526,8 +526,8 @@ def merge_from(self, status_resp: _Union[Self, _v2.StatusResponse]): # update API -class UpdateRequestEcu(MessageWrapper[_v2.UpdateRequestEcu]): - __slots__ = calculate_slots(_v2.UpdateRequestEcu) +class UpdateRequestEcu(MessageWrapper[pb2.UpdateRequestEcu]): + __slots__ = calculate_slots(pb2.UpdateRequestEcu) cookies: str ecu_id: str url: str @@ -543,8 +543,8 @@ def __init__( ) -> None: ... -class UpdateRequest(ECUList[UpdateRequestEcu], MessageWrapper[_v2.UpdateRequest]): - __slots__ = calculate_slots(_v2.UpdateRequest) +class UpdateRequest(ECUList[UpdateRequestEcu], MessageWrapper[pb2.UpdateRequest]): + __slots__ = calculate_slots(pb2.UpdateRequest) ecu: RepeatedCompositeContainer[UpdateRequestEcu] def __init__( @@ -552,8 +552,8 @@ def __init__( ) -> None: ... -class UpdateResponseEcu(MessageWrapper[_v2.UpdateResponseEcu]): - __slots__ = calculate_slots(_v2.UpdateResponseEcu) +class UpdateResponseEcu(MessageWrapper[pb2.UpdateResponseEcu]): + __slots__ = calculate_slots(pb2.UpdateResponseEcu) ecu_id: str result: FailureType @@ -565,8 +565,8 @@ def __init__( ) -> None: ... -class UpdateResponse(ECUList[UpdateResponseEcu], MessageWrapper[_v2.UpdateResponse]): - __slots__ = calculate_slots(_v2.UpdateResponse) +class UpdateResponse(ECUList[UpdateResponseEcu], MessageWrapper[pb2.UpdateResponse]): + __slots__ = calculate_slots(pb2.UpdateResponse) ecu: RepeatedCompositeContainer[UpdateResponseEcu] def __init__( @@ -581,8 +581,8 @@ def ecus_acked_update(self) -> _Set[str]: if ecu_resp.result is FailureType.NO_FAILURE } - def merge_from(self, update_response: _Union[Self, _v2.UpdateResponse]): - if isinstance(update_response, _v2.UpdateResponse): + def merge_from(self, update_response: _Union[Self, pb2.UpdateResponse]): + if isinstance(update_response, pb2.UpdateResponse): update_response = self.__class__.convert(update_response) # NOTE, TODO: duplication check is not done self.ecu.extend(update_response.ecu) diff --git a/src/otaclient/_utils/__init__.py b/src/otaclient_common/__init__.py similarity index 67% rename from src/otaclient/_utils/__init__.py rename to src/otaclient_common/__init__.py index dd75b77f2..a73c58270 100644 --- a/src/otaclient/_utils/__init__.py +++ b/src/otaclient_common/__init__.py @@ -11,46 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Common shared libs for otaclient.""" from __future__ import annotations +import importlib.util import os from math import ceil from pathlib import Path -from typing import Any, Callable, Optional, TypeVar - -from typing_extensions import Concatenate, Literal, ParamSpec - -P = ParamSpec("P") - - -def copy_callable_typehint(_source: Callable[P, Any]): - """This helper function return a decorator that can type hint the target - function as the _source function. - - At runtime, this decorator actually does nothing, but just return the input function as it. - But the returned function will have the same type hint as the source function in ide. - It will not impact the runtime behavior of the decorated function. - """ - - def _decorator(target) -> Callable[P, Any]: - return target - - return _decorator - - -RT = TypeVar("RT") - - -def copy_callable_typehint_to_method(_source: Callable[P, Any]): - """Works the same as copy_callable_typehint, but omit the first arg.""" - - def _decorator(target: Callable[..., RT]) -> Callable[Concatenate[Any, P], RT]: - return target # type: ignore - - return _decorator +from types import ModuleType +from typing import Optional +from typing_extensions import Literal _MultiUnits = Literal["GiB", "MiB", "KiB", "Bytes", "KB", "MB", "GB"] # fmt: off @@ -87,3 +60,16 @@ def replace_root(path: str | Path, old_root: str | Path, new_root: str | Path) - if os.path.commonpath([path, old_root]) != old_root: raise ValueError(f"{old_root=} is not the root of {path=}") return os.path.join(new_root, os.path.relpath(path, old_root)) + + +def import_from_file(path: Path) -> tuple[str, ModuleType]: + if not path.is_file(): + raise ValueError(f"{path} is not a valid module file") + try: + _module_name = path.stem + _spec = importlib.util.spec_from_file_location(_module_name, path) + _module = importlib.util.module_from_spec(_spec) # type: ignore + _spec.loader.exec_module(_module) # type: ignore + return _module_name, _module + except Exception: + raise ImportError(f"failed to import module from {path=}.") diff --git a/src/otaclient_common/common.py b/src/otaclient_common/common.py new file mode 100644 index 000000000..10433bc2f --- /dev/null +++ b/src/otaclient_common/common.py @@ -0,0 +1,404 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils that shared between modules are listed here. + +TODO(20240603): the old otaclient.app.common, split it by functionalities in + the future. +""" + + +from __future__ import annotations + +import logging +import os +import shlex +import shutil +import subprocess +import time +from hashlib import sha256 +from pathlib import Path +from typing import Optional, Union +from urllib.parse import urljoin + +import requests + +logger = logging.getLogger(__name__) + + +def get_backoff(n: int, factor: float, _max: float) -> float: + return min(_max, factor * (2 ** (n - 1))) + + +def wait_with_backoff(_retry_cnt: int, *, _backoff_factor: float, _backoff_max: float): + time.sleep( + get_backoff( + _retry_cnt, + _backoff_factor, + _backoff_max, + ) + ) + + +# file verification +def file_sha256( + filename: Union[Path, str], *, chunk_size: int = 1 * 1024 * 1024 +) -> str: + with open(filename, "rb") as f: + m = sha256() + while True: + d = f.read(chunk_size) + if len(d) == 0: + break + m.update(d) + return m.hexdigest() + + +def verify_file(fpath: Path, fhash: str, fsize: Optional[int]) -> bool: + if ( + fpath.is_symlink() + or (not fpath.is_file()) + or (fsize is not None and fpath.stat().st_size != fsize) + ): + return False + return file_sha256(fpath) == fhash + + +# handled file read/write +def read_str_from_file(path: Union[Path, str], *, missing_ok=True, default="") -> str: + """ + Params: + missing_ok: if set to False, FileNotFoundError will be raised to upper + default: the default value to return when missing_ok=True and file not found + """ + try: + return Path(path).read_text().strip() + except FileNotFoundError: + if missing_ok: + return default + + raise + + +def write_str_to_file(path: Path, input: str): + path.write_text(input) + + +def write_str_to_file_sync(path: Union[Path, str], input: str): + with open(path, "w") as f: + f.write(input) + f.flush() + os.fsync(f.fileno()) + + +def subprocess_run_wrapper( + cmd: str | list[str], + *, + check: bool, + check_output: bool, + timeout: Optional[float] = None, +) -> subprocess.CompletedProcess[bytes]: + """A wrapper for subprocess.run method. + + NOTE: this is for the requirement of customized subprocess call + in the future, like chroot or nsenter before execution. + + Args: + cmd (str | list[str]): command to be executed. + check (bool): if True, raise CalledProcessError on non 0 return code. + check_output (bool): if True, the UTF-8 decoded stdout will be returned. + timeout (Optional[float], optional): timeout for execution. Defaults to None. + + Returns: + subprocess.CompletedProcess[bytes]: the result of the execution. + """ + if isinstance(cmd, str): + cmd = shlex.split(cmd) + + return subprocess.run( + cmd, + check=check, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE if check_output else None, + timeout=timeout, + ) + + +def subprocess_check_output( + cmd: str | list[str], + *, + raise_exception: bool = False, + default: str = "", + timeout: Optional[float] = None, +) -> str: + """Run the and return UTF-8 decoded stripped stdout. + + Args: + cmd (str | list[str]): command to be executed. + raise_exception (bool, optional): raise the underlying CalledProcessError. Defaults to False. + default (str, optional): if is False, return on underlying + subprocess call failed. Defaults to "". + timeout (Optional[float], optional): timeout for execution. Defaults to None. + + Returns: + str: UTF-8 decoded stripped stdout. + """ + try: + res = subprocess_run_wrapper( + cmd, check=True, check_output=True, timeout=timeout + ) + return res.stdout.decode().strip() + except subprocess.CalledProcessError as e: + _err_msg = ( + f"command({cmd=}) failed(retcode={e.returncode}: \n" + f"stderr={e.stderr.decode()}" + ) + logger.debug(_err_msg) + + if raise_exception: + raise + return default + + +def subprocess_call( + cmd: str | list[str], + *, + raise_exception: bool = False, + timeout: Optional[float] = None, +) -> None: + """Run the . + + Args: + cmd (str | list[str]): command to be executed. + raise_exception (bool, optional): raise the underlying CalledProcessError. Defaults to False. + timeout (Optional[float], optional): timeout for execution. Defaults to None. + """ + try: + subprocess_run_wrapper(cmd, check=True, check_output=False, timeout=timeout) + except subprocess.CalledProcessError as e: + _err_msg = ( + f"command({cmd=}) failed(retcode={e.returncode}: \n" + f"stderr={e.stderr.decode()}" + ) + logger.debug(_err_msg) + + if raise_exception: + raise + + +def copy_stat(src: Union[Path, str], dst: Union[Path, str]): + """Copy file/dir permission bits and owner info from src to dst.""" + _stat = Path(src).stat() + os.chown(dst, _stat.st_uid, _stat.st_gid) + os.chmod(dst, _stat.st_mode) + + +def copytree_identical(src: Path, dst: Path): + """Recursively copy from the src folder to dst folder. + + Source folder MUST be a dir. + + This function populate files/dirs from the src to the dst, + and make sure the dst is identical to the src. + + By updating the dst folder in-place, we can prevent the case + that the copy is interrupted and the dst is not yet fully populated. + + This function is different from shutil.copytree as follow: + 1. it covers the case that the same path points to different + file type, in this case, the dst path will be clean and + new file/dir will be populated as the src. + 2. it deals with the same symlinks by checking the link target, + re-generate the symlink if the dst symlink is not the same + as the src. + 3. it will remove files that not presented in the src, and + unconditionally override files with same path, ensuring + that the dst will be identical with the src. + + NOTE: is_file/is_dir also returns True if it is a symlink and + the link target is_file/is_dir + """ + if src.is_symlink() or not src.is_dir(): + raise ValueError(f"{src} is not a dir") + + if dst.is_symlink() or not dst.is_dir(): + logger.info(f"{dst=} doesn't exist or not a dir, cleanup and mkdir") + dst.unlink(missing_ok=True) # unlink doesn't follow the symlink + dst.mkdir(mode=src.stat().st_mode, parents=True) + + # phase1: populate files to the dst + for cur_dir, dirs, files in os.walk(src, topdown=True, followlinks=False): + _cur_dir = Path(cur_dir) + _cur_dir_on_dst = dst / _cur_dir.relative_to(src) + + # NOTE(20220803): os.walk now lists symlinks pointed to dir + # in the tuple, we have to handle this behavior + for _dir in dirs: + _src_dir = _cur_dir / _dir + _dst_dir = _cur_dir_on_dst / _dir + if _src_dir.is_symlink(): # this "dir" is a symlink to a dir + if (not _dst_dir.is_symlink()) and _dst_dir.is_dir(): + # if dst is a dir, remove it + shutil.rmtree(_dst_dir, ignore_errors=True) + else: # dst is symlink or file + _dst_dir.unlink() + _dst_dir.symlink_to(os.readlink(_src_dir)) + + # cover the edge case that dst is not a dir. + if _cur_dir_on_dst.is_symlink() or not _cur_dir_on_dst.is_dir(): + _cur_dir_on_dst.unlink(missing_ok=True) + _cur_dir_on_dst.mkdir(parents=True) + copy_stat(_cur_dir, _cur_dir_on_dst) + + # populate files + for fname in files: + _src_f = _cur_dir / fname + _dst_f = _cur_dir_on_dst / fname + + # prepare dst + # src is file but dst is a folder + # delete the dst in advance + if (not _dst_f.is_symlink()) and _dst_f.is_dir(): + # if dst is a dir, remove it + shutil.rmtree(_dst_f, ignore_errors=True) + else: + # dst is symlink or file + _dst_f.unlink(missing_ok=True) + + # copy/symlink dst as src + # if src is symlink, check symlink, re-link if needed + if _src_f.is_symlink(): + _dst_f.symlink_to(os.readlink(_src_f)) + else: + # copy/override src to dst + shutil.copy(_src_f, _dst_f, follow_symlinks=False) + copy_stat(_src_f, _dst_f) + + # phase2: remove unused files in the dst + for cur_dir, dirs, files in os.walk(dst, topdown=True, followlinks=False): + _cur_dir_on_dst = Path(cur_dir) + _cur_dir_on_src = src / _cur_dir_on_dst.relative_to(dst) + + # remove unused dir + if not _cur_dir_on_src.is_dir(): + shutil.rmtree(_cur_dir_on_dst, ignore_errors=True) + dirs.clear() # stop iterate the subfolders of this dir + continue + + # NOTE(20220803): os.walk now lists symlinks pointed to dir + # in the tuple, we have to handle this behavior + for _dir in dirs: + _src_dir = _cur_dir_on_src / _dir + _dst_dir = _cur_dir_on_dst / _dir + if (not _src_dir.is_symlink()) and _dst_dir.is_symlink(): + _dst_dir.unlink() + + for fname in files: + _src_f = _cur_dir_on_src / fname + if not (_src_f.is_symlink() or _src_f.is_file()): + (_cur_dir_on_dst / fname).unlink(missing_ok=True) + + +def re_symlink_atomic(src: Path, target: Union[Path, str]): + """Make the a symlink to atomically. + + If the src is already existed as a file/symlink, + the src will be replaced by the newly created link unconditionally. + + NOTE: os.rename is atomic when src and dst are on + the same filesystem under linux. + NOTE 2: src should not exist or exist as file/symlink. + """ + if not (src.is_symlink() and str(os.readlink(src)) == str(target)): + tmp_link = Path(src).parent / f"tmp_link_{os.urandom(6).hex()}" + try: + tmp_link.symlink_to(target) + os.rename(tmp_link, src) # unconditionally override + except Exception: + tmp_link.unlink(missing_ok=True) + raise + + +def replace_atomic(src: Union[str, Path], dst: Union[str, Path]): + """Atomically replace dst file with src file. + + NOTE: atomic is ensured by os.rename/os.replace under the same filesystem. + """ + src, dst = Path(src), Path(dst) + if not src.is_file(): + raise ValueError(f"{src=} is not a regular file or not exist") + + _tmp_file = dst.parent / f".tmp_{os.urandom(6).hex()}" + try: + # prepare a copy of src file under dst's parent folder + shutil.copy(src, _tmp_file, follow_symlinks=True) + # atomically rename/replace the dst file with the copy + os.replace(_tmp_file, dst) + os.sync() + except Exception: + _tmp_file.unlink(missing_ok=True) + raise + + +def urljoin_ensure_base(base: str, url: str): + """ + NOTE: this method ensure the base_url will be preserved. + for example: + base="http://example.com/data", url="path/to/file" + with urljoin, joined url will be "http://example.com/path/to/file", + with this func, joined url will be "http://example.com/data/path/to/file" + """ + return urljoin(f"{base.rstrip('/')}/", url) + + +def create_tmp_fname(prefix="tmp", length=6, sep="_") -> str: + return f"{prefix}{sep}{os.urandom(length).hex()}" + + +def ensure_otaproxy_start( + otaproxy_url: str, + *, + interval: float = 1, + connection_timeout: float = 5, + probing_timeout: Optional[float] = None, + warning_interval: int = 3 * 60, # seconds +): + """Loop probing until online or exceed . + + This function will issue a logging.warning every seconds. + + Raises: + A ConnectionError if exceeds . + """ + start_time = int(time.time()) + next_warning = start_time + warning_interval + probing_timeout = ( + probing_timeout if probing_timeout and probing_timeout >= 0 else float("inf") + ) + with requests.Session() as session: + while start_time + probing_timeout > (cur_time := int(time.time())): + try: + resp = session.get(otaproxy_url, timeout=connection_timeout) + resp.close() + return + except Exception as e: # server is not up yet + if cur_time >= next_warning: + logger.warning( + f"otaproxy@{otaproxy_url} is not up after {cur_time - start_time} seconds" + f"it might be something wrong with this otaproxy: {e!r}" + ) + next_warning = next_warning + warning_interval + time.sleep(interval) + raise ConnectionError( + f"failed to ensure connection to {otaproxy_url} in {probing_timeout=}seconds" + ) diff --git a/src/otaclient/app/downloader.py b/src/otaclient_common/downloader.py similarity index 98% rename from src/otaclient/app/downloader.py rename to src/otaclient_common/downloader.py index 173d9b1a3..b67a3371c 100644 --- a/src/otaclient/app/downloader.py +++ b/src/otaclient_common/downloader.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""A common used downloader implementation for otaclient.""" +from __future__ import annotations + import errno import logging import os @@ -49,10 +52,9 @@ from urllib3.util.retry import Retry from ota_proxy import OTAFileCacheControl -from otaclient._utils import copy_callable_typehint - -from .common import wait_with_backoff -from .configs import config as cfg +from otaclient.app.configs import config as cfg +from otaclient_common.common import wait_with_backoff +from otaclient_common.typing import copy_callable_typehint logger = logging.getLogger(__name__) diff --git a/src/otaclient/_utils/linux.py b/src/otaclient_common/linux.py similarity index 100% rename from src/otaclient/_utils/linux.py rename to src/otaclient_common/linux.py diff --git a/src/otaclient/_utils/logging.py b/src/otaclient_common/logging.py similarity index 100% rename from src/otaclient/_utils/logging.py rename to src/otaclient_common/logging.py diff --git a/src/otaclient_common/persist_file_handling.py b/src/otaclient_common/persist_file_handling.py new file mode 100644 index 000000000..6b0178017 --- /dev/null +++ b/src/otaclient_common/persist_file_handling.py @@ -0,0 +1,219 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +import os +import shutil +from functools import lru_cache, partial +from pathlib import Path + +from otaclient_common.linux import ( + ParsedGroup, + ParsedPasswd, + map_gid_by_grpnam, + map_uid_by_pwnam, +) + +logger = logging.getLogger(__name__) + + +class PersistFilesHandler: + """Preserving files in persist list from to . + + Files being copied will have mode bits preserved, + and uid/gid preserved with mapping as follow: + + src_uid -> src_name -> dst_name -> dst_uid + src_gid -> src_name -> dst_name -> dst_gid + """ + + def __init__( + self, + src_passwd_file: str | Path, + src_group_file: str | Path, + dst_passwd_file: str | Path, + dst_group_file: str | Path, + *, + src_root: str | Path, + dst_root: str | Path, + ): + self._uid_mapper = lru_cache()( + partial( + self.map_uid_by_pwnam, + src_db=ParsedPasswd(src_passwd_file), + dst_db=ParsedPasswd(dst_passwd_file), + ) + ) + self._gid_mapper = lru_cache()( + partial( + self.map_gid_by_grpnam, + src_db=ParsedGroup(src_group_file), + dst_db=ParsedGroup(dst_group_file), + ) + ) + self._src_root = Path(src_root) + self._dst_root = Path(dst_root) + + @staticmethod + def map_uid_by_pwnam( + *, src_db: ParsedPasswd, dst_db: ParsedPasswd, uid: int + ) -> int: + _mapped_uid = map_uid_by_pwnam(src_db=src_db, dst_db=dst_db, uid=uid) + _usern = src_db._by_uid[uid] + + logger.info(f"{_usern=}: mapping src_{uid=} to {_mapped_uid=}") + return _mapped_uid + + @staticmethod + def map_gid_by_grpnam(*, src_db: ParsedGroup, dst_db: ParsedGroup, gid: int) -> int: + _mapped_gid = map_gid_by_grpnam(src_db=src_db, dst_db=dst_db, gid=gid) + _groupn = src_db._by_gid[gid] + + logger.info(f"{_groupn=}: mapping src_{gid=} to {_mapped_gid=}") + return _mapped_gid + + def _chown_with_mapping( + self, _src_stat: os.stat_result, _dst_path: str | Path + ) -> None: + _src_uid, _src_gid = _src_stat.st_uid, _src_stat.st_gid + try: + _dst_uid = self._uid_mapper(uid=_src_uid) + except ValueError: + logger.warning(f"failed to find mapping for {_src_uid=}, keep unchanged") + _dst_uid = _src_uid + + try: + _dst_gid = self._gid_mapper(gid=_src_gid) + except ValueError: + logger.warning(f"failed to find mapping for {_src_gid=}, keep unchanged") + _dst_gid = _src_gid + os.chown(_dst_path, uid=_dst_uid, gid=_dst_gid, follow_symlinks=False) + + @staticmethod + def _rm_target(_target: Path) -> None: + """Remove target with proper methods.""" + if _target.is_symlink() or _target.is_file(): + return _target.unlink(missing_ok=True) + elif _target.is_dir(): + return shutil.rmtree(_target, ignore_errors=True) + elif _target.exists(): + raise ValueError( + f"{_target} is not normal file/symlink/dir, failed to remove" + ) + + def _prepare_symlink(self, _src_path: Path, _dst_path: Path) -> None: + _dst_path.symlink_to(os.readlink(_src_path)) + # NOTE: to get stat from symlink, using os.stat with follow_symlinks=False + self._chown_with_mapping(os.stat(_src_path, follow_symlinks=False), _dst_path) + + def _prepare_dir(self, _src_path: Path, _dst_path: Path) -> None: + _dst_path.mkdir(exist_ok=True) + + _src_stat = os.stat(_src_path, follow_symlinks=False) + os.chmod(_dst_path, _src_stat.st_mode) + self._chown_with_mapping(_src_stat, _dst_path) + + def _prepare_file(self, _src_path: Path, _dst_path: Path) -> None: + shutil.copy(_src_path, _dst_path, follow_symlinks=False) + + _src_stat = os.stat(_src_path, follow_symlinks=False) + os.chmod(_dst_path, _src_stat.st_mode) + self._chown_with_mapping(_src_stat, _dst_path) + + def _prepare_parent(self, _origin_entry: Path) -> None: + for _parent in reversed(_origin_entry.parents): + _src_parent, _dst_parent = ( + self._src_root / _parent, + self._dst_root / _parent, + ) + if _dst_parent.is_dir(): # keep the origin parent on dst as it + continue + if _dst_parent.is_symlink() or _dst_parent.is_file(): + _dst_parent.unlink(missing_ok=True) + self._prepare_dir(_src_parent, _dst_parent) + continue + if _dst_parent.exists(): + raise ValueError( + f"{_dst_parent=} is not a normal file/symlink/dir, cannot cleanup" + ) + self._prepare_dir(_src_parent, _dst_parent) + + # API + + def preserve_persist_entry( + self, _persist_entry: str | Path, *, src_missing_ok: bool = True + ): + logger.info(f"preserving {_persist_entry}") + # persist_entry in persists.txt must be rooted at / + origin_entry = Path(_persist_entry).relative_to("/") + src_path = self._src_root / origin_entry + dst_path = self._dst_root / origin_entry + + # ------ src is symlink ------ # + # NOTE: always check if symlink first as is_file/is_dir/exists all follow_symlinks + if src_path.is_symlink(): + self._rm_target(dst_path) + self._prepare_parent(origin_entry) + self._prepare_symlink(src_path, dst_path) + return + + # ------ src is file ------ # + if src_path.is_file(): + self._rm_target(dst_path) + self._prepare_parent(origin_entry) + self._prepare_file(src_path, dst_path) + return + + # ------ src is not regular file/symlink/dir ------ # + # we only process normal file/symlink/dir + if src_path.exists() and not src_path.is_dir(): + raise ValueError(f"{src_path=} must be either a file/symlink/dir") + + # ------ src doesn't exist ------ # + if not src_path.exists(): + _err_msg = f"{src_path=} not found" + logger.warning(_err_msg) + if not src_missing_ok: + raise ValueError(_err_msg) + return + + # ------ src is dir ------ # + # dive into src_dir and preserve everything under the src dir + self._prepare_parent(origin_entry) + for src_curdir, dnames, fnames in os.walk(src_path, followlinks=False): + src_cur_dpath = Path(src_curdir) + dst_cur_dpath = self._dst_root / src_cur_dpath.relative_to(self._src_root) + + # ------ prepare current dir itself ------ # + self._rm_target(dst_cur_dpath) + self._prepare_dir(src_cur_dpath, dst_cur_dpath) + + # ------ prepare entries in current dir ------ # + for _fname in fnames: + _src_fpath, _dst_fpath = src_cur_dpath / _fname, dst_cur_dpath / _fname + self._rm_target(_dst_fpath) + if _src_fpath.is_symlink(): + self._prepare_symlink(_src_fpath, _dst_fpath) + continue + self._prepare_file(_src_fpath, _dst_fpath) + + # symlinks to dirs also included in dnames, we must handle it + for _dname in dnames: + _src_dpath, _dst_dpath = src_cur_dpath / _dname, dst_cur_dpath / _dname + if _src_dpath.is_symlink(): + self._rm_target(_dst_dpath) + self._prepare_symlink(_src_dpath, _dst_dpath) diff --git a/src/otaclient/app/proto/streamer.py b/src/otaclient_common/proto_streamer.py similarity index 97% rename from src/otaclient/app/proto/streamer.py rename to src/otaclient_common/proto_streamer.py index fb887dc01..e52934730 100644 --- a/src/otaclient/app/proto/streamer.py +++ b/src/otaclient_common/proto_streamer.py @@ -23,7 +23,7 @@ from typing import BinaryIO, Generic, Iterable, Optional, Type -from ._common import MessageType, MessageWrapperType +from otaclient_common.proto_wrapper import MessageType, MessageWrapperType UINT32_LEN = 4 # bytes diff --git a/src/otaclient/app/proto/_common.py b/src/otaclient_common/proto_wrapper.py similarity index 100% rename from src/otaclient/app/proto/_common.py rename to src/otaclient_common/proto_wrapper.py diff --git a/src/otaclient/app/proto/README.md b/src/otaclient_common/proto_wrapper_README.md similarity index 100% rename from src/otaclient/app/proto/README.md rename to src/otaclient_common/proto_wrapper_README.md diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py new file mode 100644 index 000000000..80c10e061 --- /dev/null +++ b/src/otaclient_common/retry_task_map.py @@ -0,0 +1,225 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import itertools +import logging +import threading +import time +from concurrent.futures import Future, ThreadPoolExecutor +from functools import partial +from queue import Queue +from typing import ( + Any, + Callable, + Generator, + Generic, + Iterable, + NamedTuple, + Optional, + Set, + TypeVar, +) + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class DoneTask(NamedTuple): + fut: Future + entry: Any + + +class RetryTaskMapInterrupted(Exception): + pass + + +class _TaskMap(Generic[T]): + def __init__( + self, + executor: ThreadPoolExecutor, + max_concurrent: int, + backoff_func: Callable[[int], float], + ) -> None: + # task dispatch interval for continues failling + self.started = False # can only be started once + self._backoff_func = backoff_func + self._executor = executor + self._shutdown_event = threading.Event() + self._se = threading.Semaphore(max_concurrent) + + self._total_tasks_count = 0 + self._dispatched_tasks: Set[Future] = set() + self._failed_tasks: Set[T] = set() + self._last_failed_fut: Optional[Future] = None + + # NOTE: itertools.count is only thread-safe in CPython with GIL, + # as itertools.count is pure C implemented, calling next over + # it is atomic in Python level. + self._done_task_counter = itertools.count(start=1) + self._all_done = threading.Event() + self._dispatch_done = False + + self._done_que: Queue[DoneTask] = Queue() + + def _done_task_cb(self, item: T, fut: Future): + """ + Tracking done counting, set all_done event. + add failed to failed list. + """ + self._se.release() # always release se first + # NOTE: don't change dispatched_tasks if shutdown_event is set + if self._shutdown_event.is_set(): + return + + self._dispatched_tasks.discard(fut) + # check if we finish all tasks + _done_task_num = next(self._done_task_counter) + if self._dispatch_done and _done_task_num == self._total_tasks_count: + logger.debug("all done!") + self._all_done.set() + + if fut.exception(): + self._failed_tasks.add(item) + self._last_failed_fut = fut + self._done_que.put_nowait(DoneTask(fut, item)) + + def _task_dispatcher(self, func: Callable[[T], Any], _iter: Iterable[T]): + """A dispatcher in a dedicated thread that dispatches + tasks to threadpool.""" + for item in _iter: + if self._shutdown_event.is_set(): + return + self._se.acquire() + self._total_tasks_count += 1 + + fut = self._executor.submit(func, item) + fut.add_done_callback(partial(self._done_task_cb, item)) + self._dispatched_tasks.add(fut) + logger.debug(f"dispatcher done: {self._total_tasks_count=}") + self._dispatch_done = True + + def _done_task_collector(self) -> Generator[DoneTask, None, None]: + """A generator for caller to yield done task from.""" + _count = 0 + while not self._shutdown_event.is_set(): + if self._all_done.is_set() and _count == self._total_tasks_count: + logger.debug("collector done!") + return + + yield self._done_que.get() + _count += 1 + + def map(self, func: Callable[[T], Any], _iter: Iterable[T]): + if self.started: + raise ValueError(f"{self.__class__} inst can only be started once") + self.started = True + + self._task_dispatcher_fut = self._executor.submit( + self._task_dispatcher, func, _iter + ) + self._task_collector_gen = self._done_task_collector() + return self._task_collector_gen + + def shutdown(self, *, raise_last_exc=False) -> Optional[Set[T]]: + """Set the shutdown event, and cancal/cleanup ongoing tasks.""" + if not self.started or self._shutdown_event.is_set(): + return + + self._shutdown_event.set() + self._task_collector_gen.close() + # wait for dispatch to stop + self._task_dispatcher_fut.result() + + # cancel all the dispatched tasks + for fut in self._dispatched_tasks: + fut.cancel() + self._dispatched_tasks.clear() + + if not self._failed_tasks: + return + try: + if self._last_failed_fut: + _exc = self._last_failed_fut.exception() + _err_msg = f"{len(self._failed_tasks)=}, last failed: {_exc!r}" + if raise_last_exc: + raise RetryTaskMapInterrupted(_err_msg) from _exc + else: + logger.warning(_err_msg) + return self._failed_tasks.copy() + finally: + # be careful not to create ref cycle here + self._failed_tasks.clear() + _exc, self = None, None + + +class RetryTaskMap(Generic[T]): + def __init__( + self, + *, + backoff_func: Callable[[int], float], + max_retry: int, + max_concurrent: int, + max_workers: Optional[int] = None, + ) -> None: + self._running_inst: Optional[_TaskMap] = None + self._map_gen: Optional[Generator] = None + + self._backoff_func = backoff_func + self._retry_counter = range(max_retry) if max_retry else itertools.count() + self._max_concurrent = max_concurrent + self._max_workers = max_workers + self._executor = ThreadPoolExecutor(max_workers=self._max_workers) + + def map( + self, _func: Callable[[T], Any], _iter: Iterable[T] + ) -> Generator[DoneTask, None, None]: + retry_round = 0 + for retry_round in self._retry_counter: + self._running_inst = _inst = _TaskMap( + self._executor, self._max_concurrent, self._backoff_func + ) + logger.debug(f"{retry_round=} started") + + yield from _inst.map(_func, _iter) + + # this retry round ends, check overall result + if _failed_list := _inst.shutdown(raise_last_exc=False): + _iter = _failed_list # feed failed to next round + # deref before entering sleep + self._running_inst, _inst = None, None + + logger.warning(f"retry#{retry_round+1}: retry on {len(_failed_list)=}") + time.sleep(self._backoff_func(retry_round)) + else: # all tasks finished successfully + self._running_inst, _inst = None, None + return + try: + raise RetryTaskMapInterrupted(f"exceed try limit: {retry_round}") + finally: + # cleanup the defs + _func, _iter = None, None # type: ignore + + def shutdown(self, *, raise_last_exc: bool): + try: + logger.debug("shutdown retry task map") + if self._running_inst: + self._running_inst.shutdown(raise_last_exc=raise_last_exc) + # NOTE: passthrough the exception from underlying running_inst + finally: + self._running_inst = None + self._executor.shutdown(wait=True) diff --git a/src/otaclient/_utils/typing.py b/src/otaclient_common/typing.py similarity index 55% rename from src/otaclient/_utils/typing.py rename to src/otaclient_common/typing.py index 941935f93..b487a818a 100644 --- a/src/otaclient/_utils/typing.py +++ b/src/otaclient_common/typing.py @@ -20,12 +20,13 @@ from typing import Any, Callable, TypeVar, Union from pydantic import Field -from typing_extensions import Annotated, ParamSpec +from typing_extensions import Annotated, Concatenate, ParamSpec P = ParamSpec("P") -T = TypeVar("T", bound=Enum) RT = TypeVar("RT") +T = TypeVar("T") +EnumT = TypeVar("EnumT", bound=Enum) StrOrPath = Union[str, Path] # pydantic helpers @@ -33,7 +34,9 @@ NetworkPort = Annotated[int, Field(ge=1, le=65535)] -def gen_strenum_validator(enum_type: type[T]) -> Callable[[T | str | Any], T]: +def gen_strenum_validator( + enum_type: type[EnumT], +) -> Callable[[EnumT | str | Any], EnumT]: """A before validator generator that converts input value into enum before passing it to pydantic validator. @@ -41,10 +44,34 @@ def gen_strenum_validator(enum_type: type[T]) -> Callable[[T | str | Any], T]: pass strict validation if input is str. """ - def _inner(value: T | str | Any) -> T: + def _inner(value: EnumT | str | Any) -> EnumT: assert isinstance( value, (enum_type, str) ), f"{value=} should be {enum_type} or str type" return enum_type(value) return _inner + + +def copy_callable_typehint(_source: Callable[P, Any]): + """This helper function return a decorator that can type hint the target + function as the _source function. + + At runtime, this decorator actually does nothing, but just return the input function as it. + But the returned function will have the same type hint as the source function in ide. + It will not impact the runtime behavior of the decorated function. + """ + + def _decorator(target) -> Callable[P, Any]: + return target + + return _decorator + + +def copy_callable_typehint_to_method(_source: Callable[P, Any]): + """Works the same as copy_callable_typehint, but omit the first arg.""" + + def _decorator(target: Callable[..., RT]) -> Callable[Concatenate[Any, P], RT]: + return target # type: ignore + + return _decorator diff --git a/tests/conftest.py b/tests/conftest.py index c923caf1a..2e060f7a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) +TEST_DIR = Path(__file__).parent + @dataclass class TestConfiguration: @@ -40,8 +42,7 @@ class TestConfiguration: RPI_BOOT_MODULE_PATH = "otaclient.app.boot_control._rpi_boot" OTACLIENT_MODULE_PATH = "otaclient.app.ota_client" OTACLIENT_STUB_MODULE_PATH = "otaclient.app.ota_client_stub" - OTACLIENT_SERVICE_MODULE_PATH = "otaclient.app.ota_client_service" - OTAMETA_MODULE_PATH = "otaclient.app.ota_metadata" + OTAMETA_MODULE_PATH = "ota_metadata.legacy.parser" OTAPROXY_MODULE_PATH = "ota_proxy" CREATE_STANDBY_MODULE_PATH = "otaclient.app.create_standby" MAIN_MODULE_PATH = "otaclient.app.main" diff --git a/tests/test__utils/test_logging.py b/tests/test_logging.py similarity index 97% rename from tests/test__utils/test_logging.py rename to tests/test_logging.py index cdfd627ce..ef3c15e18 100644 --- a/tests/test__utils/test_logging.py +++ b/tests/test_logging.py @@ -18,7 +18,7 @@ from pytest import LogCaptureFixture -from otaclient._utils import logging as _logging +from otaclient_common import logging as _logging def test_BurstSuppressFilter(caplog: LogCaptureFixture): diff --git a/tests/test_ota_metadata/__init__.py b/tests/test_ota_metadata/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_ota_metadata.py b/tests/test_ota_metadata/test_legacy.py similarity index 99% rename from tests/test_ota_metadata.py rename to tests/test_ota_metadata/test_legacy.py index eb5bb502d..74e550a8f 100644 --- a/tests/test_ota_metadata.py +++ b/tests/test_ota_metadata/test_legacy.py @@ -27,7 +27,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec -from otaclient.app.ota_metadata import ( +from ota_metadata.legacy.parser import ( MetadataJWTPayloadInvalid, MetadataJWTVerificationFailed, _MetadataJWTParser, @@ -36,6 +36,9 @@ parse_regulars_from_txt, parse_symlinks_from_txt, ) +from tests.conftest import TEST_DIR + +GEN_CERTS_SCRIPT = TEST_DIR / "keys" / "gen_certs.sh" HEADER = """\ {"alg": "ES256"}\ @@ -162,10 +165,6 @@ def generate_jwt(payload_str: str, sign_key_file: Path): return f"{header}.{payload}.{signature}" -TEST_DIR = Path(__file__).parent -GEN_CERTS_SCRIPT = TEST_DIR / "keys" / "gen_certs.sh" - - class CertsDirs(TypedDict): multi_chain: Path chain_a: Path diff --git a/tests/test_ota_proxy/test_ota_proxy_server.py b/tests/test_ota_proxy/test_ota_proxy_server.py index c38015cda..361ef9868 100644 --- a/tests/test_ota_proxy/test_ota_proxy_server.py +++ b/tests/test_ota_proxy/test_ota_proxy_server.py @@ -27,9 +27,9 @@ import pytest import uvicorn +from ota_metadata.legacy.parser import parse_regulars_from_txt +from ota_metadata.legacy.types import RegularInf from ota_proxy.utils import url_based_hash -from otaclient.app.ota_metadata import parse_regulars_from_txt -from otaclient.app.proto.wrapper import RegularInf from tests.conftest import ThreadpoolExecutorFixtureMixin, cfg logger = logging.getLogger(__name__) diff --git a/tests/test_otaclient/__init__.py b/tests/test_otaclient/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_boot_control/__init__.py b/tests/test_otaclient/test_boot_control/__init__.py similarity index 100% rename from tests/test_boot_control/__init__.py rename to tests/test_otaclient/test_boot_control/__init__.py diff --git a/tests/test_boot_control/default_grub b/tests/test_otaclient/test_boot_control/default_grub similarity index 100% rename from tests/test_boot_control/default_grub rename to tests/test_otaclient/test_boot_control/default_grub diff --git a/tests/test_boot_control/extlinux.conf_slot_a b/tests/test_otaclient/test_boot_control/extlinux.conf_slot_a similarity index 100% rename from tests/test_boot_control/extlinux.conf_slot_a rename to tests/test_otaclient/test_boot_control/extlinux.conf_slot_a diff --git a/tests/test_boot_control/extlinux.conf_slot_b b/tests/test_otaclient/test_boot_control/extlinux.conf_slot_b similarity index 100% rename from tests/test_boot_control/extlinux.conf_slot_b rename to tests/test_otaclient/test_boot_control/extlinux.conf_slot_b diff --git a/tests/test_boot_control/fstab_origin b/tests/test_otaclient/test_boot_control/fstab_origin similarity index 100% rename from tests/test_boot_control/fstab_origin rename to tests/test_otaclient/test_boot_control/fstab_origin diff --git a/tests/test_boot_control/fstab_updated b/tests/test_otaclient/test_boot_control/fstab_updated similarity index 100% rename from tests/test_boot_control/fstab_updated rename to tests/test_otaclient/test_boot_control/fstab_updated diff --git a/tests/test_boot_control/grub.cfg_slot_a b/tests/test_otaclient/test_boot_control/grub.cfg_slot_a similarity index 100% rename from tests/test_boot_control/grub.cfg_slot_a rename to tests/test_otaclient/test_boot_control/grub.cfg_slot_a diff --git a/tests/test_boot_control/grub.cfg_slot_a_non_otapartition b/tests/test_otaclient/test_boot_control/grub.cfg_slot_a_non_otapartition similarity index 100% rename from tests/test_boot_control/grub.cfg_slot_a_non_otapartition rename to tests/test_otaclient/test_boot_control/grub.cfg_slot_a_non_otapartition diff --git a/tests/test_boot_control/grub.cfg_slot_a_updated b/tests/test_otaclient/test_boot_control/grub.cfg_slot_a_updated similarity index 100% rename from tests/test_boot_control/grub.cfg_slot_a_updated rename to tests/test_otaclient/test_boot_control/grub.cfg_slot_a_updated diff --git a/tests/test_boot_control/grub.cfg_slot_b b/tests/test_otaclient/test_boot_control/grub.cfg_slot_b similarity index 100% rename from tests/test_boot_control/grub.cfg_slot_b rename to tests/test_otaclient/test_boot_control/grub.cfg_slot_b diff --git a/tests/test_boot_control/grub.cfg_slot_b_updated b/tests/test_otaclient/test_boot_control/grub.cfg_slot_b_updated similarity index 100% rename from tests/test_boot_control/grub.cfg_slot_b_updated rename to tests/test_otaclient/test_boot_control/grub.cfg_slot_b_updated diff --git a/tests/test_boot_control/test_grub.py b/tests/test_otaclient/test_boot_control/test_grub.py similarity index 98% rename from tests/test_boot_control/test_grub.py rename to tests/test_otaclient/test_boot_control/test_grub.py index b8560e837..8a3cd1401 100644 --- a/tests/test_boot_control/test_grub.py +++ b/tests/test_otaclient/test_boot_control/test_grub.py @@ -22,7 +22,7 @@ import pytest import pytest_mock -from otaclient.app.proto import wrapper +from otaclient_api.v2 import types as api_types from tests.conftest import TestConfiguration as cfg from tests.utils import SlotMeta @@ -323,7 +323,7 @@ def test_grub_normal_update(self, mocker: pytest_mock.MockerFixture): grub_controller = GrubController() assert ( self.slot_a_ota_partition_dir / "status" - ).read_text() == wrapper.StatusOta.INITIALIZED.name + ).read_text() == api_types.StatusOta.INITIALIZED.name # assert ota-partition file points to slot_a ota-partition folder assert ( os.readlink(self.boot_dir / cfg.OTA_PARTITION_DIRNAME) @@ -342,10 +342,10 @@ def test_grub_normal_update(self, mocker: pytest_mock.MockerFixture): # update slot_b, slot_a_ota_status->FAILURE, slot_b_ota_status->UPDATING assert ( self.slot_a_ota_partition_dir / "status" - ).read_text() == wrapper.StatusOta.FAILURE.name + ).read_text() == api_types.StatusOta.FAILURE.name assert ( self.slot_b_ota_partition_dir / "status" - ).read_text() == wrapper.StatusOta.UPDATING.name + ).read_text() == api_types.StatusOta.UPDATING.name # NOTE: we have to copy the new kernel files to the slot_b's boot dir # this is done by the create_standby module _kernel = f"{cfg.KERNEL_PREFIX}-{cfg.KERNEL_VERSION}" @@ -385,7 +385,7 @@ def test_grub_normal_update(self, mocker: pytest_mock.MockerFixture): assert self._fsm.is_boot_switched assert ( self.slot_b_ota_partition_dir / "status" - ).read_text() == wrapper.StatusOta.UPDATING.name + ).read_text() == api_types.StatusOta.UPDATING.name # assert ota-partition file is not yet switched before first reboot init assert ( os.readlink(self.boot_dir / cfg.OTA_PARTITION_DIRNAME) @@ -401,7 +401,7 @@ def test_grub_normal_update(self, mocker: pytest_mock.MockerFixture): ) assert ( self.slot_b_ota_partition_dir / "status" - ).read_text() == wrapper.StatusOta.SUCCESS.name + ).read_text() == api_types.StatusOta.SUCCESS.name assert ( self.slot_b_ota_partition_dir / "version" ).read_text() == cfg.UPDATE_VERSION diff --git a/tests/test_boot_control/test_jetson_cboot.py b/tests/test_otaclient/test_boot_control/test_jetson_cboot.py similarity index 95% rename from tests/test_boot_control/test_jetson_cboot.py rename to tests/test_otaclient/test_boot_control/test_jetson_cboot.py index 8183eb57f..937c6b677 100644 --- a/tests/test_boot_control/test_jetson_cboot.py +++ b/tests/test_otaclient/test_boot_control/test_jetson_cboot.py @@ -33,11 +33,12 @@ parse_bsp_version, update_extlinux_cfg, ) +from tests.conftest import TEST_DIR logger = logging.getLogger(__name__) MODULE_NAME = _jetson_cboot.__name__ -TEST_DIR = Path(__file__).parent.parent / "data" +TEST_DATA_DIR = TEST_DIR / "data" def test_SlotID(): @@ -139,6 +140,6 @@ def test_parse_bsp_version(_in: str, expected: BSPVersion): ), ) def test_update_extlinux_conf(_template_f: Path, _updated_f: Path, partuuid: str): - _in = (TEST_DIR / _template_f).read_text() - _expected = (TEST_DIR / _updated_f).read_text() + _in = (TEST_DATA_DIR / _template_f).read_text() + _expected = (TEST_DATA_DIR / _updated_f).read_text() assert update_extlinux_cfg(_in, partuuid) == _expected diff --git a/tests/test_boot_control/test_ota_status_control.py b/tests/test_otaclient/test_boot_control/test_ota_status_control.py similarity index 86% rename from tests/test_boot_control/test_ota_status_control.py rename to tests/test_otaclient/test_boot_control/test_ota_status_control.py index 053ad88d5..6a3d2b6a6 100644 --- a/tests/test_boot_control/test_ota_status_control.py +++ b/tests/test_otaclient/test_boot_control/test_ota_status_control.py @@ -23,8 +23,8 @@ from otaclient.app.boot_control._common import OTAStatusFilesControl from otaclient.app.boot_control.configs import BaseConfig as cfg -from otaclient.app.common import read_str_from_file, write_str_to_file -from otaclient.app.proto import wrapper +from otaclient_api.v2 import types as api_types +from otaclient_common.common import read_str_from_file, write_str_to_file logger = logging.getLogger(__name__) @@ -74,27 +74,27 @@ def setup(self, tmp_path: Path): "", False, # output - wrapper.StatusOta.INITIALIZED, + api_types.StatusOta.INITIALIZED, SLOT_A_ID, ), ( "test_force_initialize", # input - wrapper.StatusOta.SUCCESS, + api_types.StatusOta.SUCCESS, SLOT_A_ID, True, # output - wrapper.StatusOta.INITIALIZED, + api_types.StatusOta.INITIALIZED, SLOT_A_ID, ), ( "test_normal_boot", # input - wrapper.StatusOta.SUCCESS, + api_types.StatusOta.SUCCESS, SLOT_A_ID, False, # output - wrapper.StatusOta.SUCCESS, + api_types.StatusOta.SUCCESS, SLOT_A_ID, ), ), @@ -102,10 +102,10 @@ def setup(self, tmp_path: Path): def test_ota_status_files_loading( self, test_case: str, - input_slot_a_status: Optional[wrapper.StatusOta], + input_slot_a_status: Optional[api_types.StatusOta], input_slot_a_slot_in_use: str, force_initialize: bool, - output_slot_a_status: wrapper.StatusOta, + output_slot_a_status: api_types.StatusOta, output_slot_a_slot_in_use: str, ): logger.info(f"{test_case=}") @@ -158,7 +158,7 @@ def test_pre_update(self): # slot_a: current slot assert ( read_str_from_file(self.slot_a_status_file) - == wrapper.StatusOta.FAILURE.name + == api_types.StatusOta.FAILURE.name ) assert ( read_str_from_file(self.slot_a_slot_in_use_file) @@ -168,7 +168,7 @@ def test_pre_update(self): # slot_b: standby slot assert ( read_str_from_file(self.slot_b_status_file) - == wrapper.StatusOta.UPDATING.name + == api_types.StatusOta.UPDATING.name ) assert read_str_from_file(self.slot_b_slot_in_use_file) == self.slot_b @@ -193,9 +193,9 @@ def test_switching_boot( """First reboot after OTA from slot_a to slot_b.""" logger.info(f"{test_case=}") # ------ setup ------ # - write_str_to_file(self.slot_a_status_file, wrapper.StatusOta.FAILURE.name) + write_str_to_file(self.slot_a_status_file, api_types.StatusOta.FAILURE.name) write_str_to_file(self.slot_a_slot_in_use_file, self.slot_b) - write_str_to_file(self.slot_b_status_file, wrapper.StatusOta.UPDATING.name) + write_str_to_file(self.slot_b_status_file, api_types.StatusOta.UPDATING.name) write_str_to_file(self.slot_b_slot_in_use_file, self.slot_b) # ------ execution ------ # @@ -218,7 +218,7 @@ def test_switching_boot( # check slot a assert ( read_str_from_file(self.slot_a_status_file) - == wrapper.StatusOta.FAILURE.name + == api_types.StatusOta.FAILURE.name ) assert ( read_str_from_file(self.slot_a_slot_in_use_file) @@ -233,25 +233,25 @@ def test_switching_boot( # finalizing succeeded if finalizing_result: - assert status_control.booted_ota_status == wrapper.StatusOta.SUCCESS + assert status_control.booted_ota_status == api_types.StatusOta.SUCCESS assert ( read_str_from_file(self.slot_b_status_file) - == wrapper.StatusOta.SUCCESS.name + == api_types.StatusOta.SUCCESS.name ) else: - assert status_control.booted_ota_status == wrapper.StatusOta.FAILURE + assert status_control.booted_ota_status == api_types.StatusOta.FAILURE assert ( read_str_from_file(self.slot_b_status_file) - == wrapper.StatusOta.FAILURE.name + == api_types.StatusOta.FAILURE.name ) def test_accidentally_boots_back_to_standby(self): """slot_a should be active slot but boots back to slot_b.""" # ------ setup ------ # - write_str_to_file(self.slot_a_status_file, wrapper.StatusOta.SUCCESS.name) + write_str_to_file(self.slot_a_status_file, api_types.StatusOta.SUCCESS.name) write_str_to_file(self.slot_a_slot_in_use_file, self.slot_a) - write_str_to_file(self.slot_b_status_file, wrapper.StatusOta.FAILURE.name) + write_str_to_file(self.slot_b_status_file, api_types.StatusOta.FAILURE.name) write_str_to_file(self.slot_b_slot_in_use_file, self.slot_a) # ------ execution ------ # @@ -268,4 +268,4 @@ def test_accidentally_boots_back_to_standby(self): # ------ assertion ------ # assert not self.finalize_switch_boot_flag.is_set() # slot_b's status is read - assert status_control.booted_ota_status == wrapper.StatusOta.FAILURE + assert status_control.booted_ota_status == api_types.StatusOta.FAILURE diff --git a/tests/test_boot_control/test_rpi_boot.py b/tests/test_otaclient/test_boot_control/test_rpi_boot.py similarity index 97% rename from tests/test_boot_control/test_rpi_boot.py rename to tests/test_otaclient/test_boot_control/test_rpi_boot.py index 59164681d..f4b5b943e 100644 --- a/tests/test_boot_control/test_rpi_boot.py +++ b/tests/test_otaclient/test_boot_control/test_rpi_boot.py @@ -10,7 +10,7 @@ from otaclient.app.boot_control._rpi_boot import _FSTAB_TEMPLATE_STR from otaclient.app.boot_control.configs import rpi_boot_cfg -from otaclient.app.proto import wrapper +from otaclient_api.v2 import types as api_types from tests.conftest import TestConfiguration as cfg from tests.utils import SlotMeta @@ -223,10 +223,10 @@ def test_rpi_boot_normal_update(self, mocker: pytest_mock.MockerFixture): # 2. make sure the mount points are prepared assert ( self.slot_a_ota_status_dir / "status" - ).read_text() == wrapper.StatusOta.FAILURE.name + ).read_text() == api_types.StatusOta.FAILURE.name assert ( self.slot_b_ota_status_dir / "status" - ).read_text() == wrapper.StatusOta.UPDATING.name + ).read_text() == api_types.StatusOta.UPDATING.name assert ( (self.slot_a_ota_status_dir / "slot_in_use").read_text() == (self.slot_b_ota_status_dir / "slot_in_use").read_text() @@ -309,7 +309,7 @@ def test_rpi_boot_normal_update(self, mocker: pytest_mock.MockerFixture): assert (self.system_boot / rpi_boot_cfg.SWITCH_BOOT_FLAG_FILE).is_file() assert ( self.slot_b_ota_status_dir / rpi_boot_cfg.OTA_STATUS_FNAME - ).read_text() == wrapper.StatusOta.UPDATING.name + ).read_text() == api_types.StatusOta.UPDATING.name # ------ boot_controller_inst3.stage1: second reboot, apply updated firmware and finish up ota update ------ # logger.info("2nd reboot: finish up ota update....") @@ -320,11 +320,12 @@ def test_rpi_boot_normal_update(self, mocker: pytest_mock.MockerFixture): # 2. make sure the flag file is cleared # 3. make sure the config.txt is still for slot_b assert ( - rpi_boot_controller4_2.get_booted_ota_status() == wrapper.StatusOta.SUCCESS + rpi_boot_controller4_2.get_booted_ota_status() + == api_types.StatusOta.SUCCESS ) assert ( self.slot_b_ota_status_dir / rpi_boot_cfg.OTA_STATUS_FNAME - ).read_text() == wrapper.StatusOta.SUCCESS.name + ).read_text() == api_types.StatusOta.SUCCESS.name assert not (self.system_boot / rpi_boot_cfg.SWITCH_BOOT_FLAG_FILE).is_file() assert ( rpi_boot_controller4_2._ota_status_control._load_current_slot_in_use() diff --git a/tests/test_ecu_info.py b/tests/test_otaclient/test_configs/test_ecu_info.py similarity index 100% rename from tests/test_ecu_info.py rename to tests/test_otaclient/test_configs/test_ecu_info.py diff --git a/tests/test_proxy_info.py b/tests/test_otaclient/test_configs/test_proxy_info.py similarity index 100% rename from tests/test_proxy_info.py rename to tests/test_otaclient/test_configs/test_proxy_info.py diff --git a/tests/test_create_standby.py b/tests/test_otaclient/test_create_standby.py similarity index 99% rename from tests/test_create_standby.py rename to tests/test_otaclient/test_create_standby.py index f4845129c..5b3156e1b 100644 --- a/tests/test_create_standby.py +++ b/tests/test_otaclient/test_create_standby.py @@ -78,7 +78,6 @@ def mock_setup(self, mocker: MockerFixture, prepare_ab_slots): _cfg.RUN_DIR = str(self.otaclient_run_dir) # type: ignore mocker.patch(f"{cfg.OTACLIENT_MODULE_PATH}.cfg", _cfg) mocker.patch(f"{cfg.CREATE_STANDBY_MODULE_PATH}.rebuild_mode.cfg", _cfg) - mocker.patch(f"{cfg.OTAMETA_MODULE_PATH}.cfg", _cfg) def test_update_with_create_standby_RebuildMode(self, mocker: MockerFixture): from otaclient.app.create_standby.rebuild_mode import RebuildMode diff --git a/tests/test_log_setting.py b/tests/test_otaclient/test_log_setting.py similarity index 100% rename from tests/test_log_setting.py rename to tests/test_otaclient/test_log_setting.py diff --git a/tests/test_main.py b/tests/test_otaclient/test_main.py similarity index 100% rename from tests/test_main.py rename to tests/test_otaclient/test_main.py diff --git a/tests/test_ota_client.py b/tests/test_otaclient/test_ota_client.py similarity index 90% rename from tests/test_ota_client.py rename to tests/test_otaclient/test_ota_client.py index 5496d647a..0fa260243 100644 --- a/tests/test_ota_client.py +++ b/tests/test_otaclient/test_ota_client.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import asyncio import shutil import threading @@ -25,6 +27,8 @@ import pytest import pytest_mock +from ota_metadata.legacy.parser import parse_dirs_from_txt, parse_regulars_from_txt +from ota_metadata.legacy.types import DirectoryInf, RegularInf from otaclient.app.boot_control import BootControllerProtocol from otaclient.app.boot_control.configs import BootloaderType from otaclient.app.configs import config as otaclient_cfg @@ -37,10 +41,8 @@ OTAServicer, _OTAUpdater, ) -from otaclient.app.ota_metadata import parse_dirs_from_txt, parse_regulars_from_txt -from otaclient.app.proto import wrapper -from otaclient.app.proto.wrapper import DirectoryInf, RegularInf from otaclient.configs.ecu_info import ECUInfo +from otaclient_api.v2 import types as api_types from tests.conftest import TestConfiguration as cfg from tests.utils import SlotMeta @@ -148,7 +150,6 @@ def mock_setup(self, mocker: pytest_mock.MockerFixture, _delta_generate): _cfg.ACTIVE_ROOTFS_PATH = str(self.slot_a) # type: ignore _cfg.RUN_DIR = str(self.otaclient_run_dir) # type: ignore mocker.patch(f"{cfg.OTACLIENT_MODULE_PATH}.cfg", _cfg) - mocker.patch(f"{cfg.OTAMETA_MODULE_PATH}.cfg", _cfg) # ------ mock stats collector ------ # mocker.patch( @@ -201,7 +202,7 @@ class Test_OTAClient: CURRENT_FIRMWARE_VERSION = "firmware_version" UPDATE_FIRMWARE_VERSION = "update_firmware_version" - MOCKED_STATUS_PROGRESS = wrapper.UpdateStatus( + MOCKED_STATUS_PROGRESS = api_types.UpdateStatus( update_firmware_version=UPDATE_FIRMWARE_VERSION, downloaded_bytes=456789, downloaded_files_num=567, @@ -233,7 +234,7 @@ def mock_setup(self, mocker: pytest_mock.MockerFixture): # patch boot_controller for otaclient initializing self.boot_controller.load_version.return_value = self.CURRENT_FIRMWARE_VERSION self.boot_controller.get_booted_ota_status.return_value = ( - wrapper.StatusOta.SUCCESS + api_types.StatusOta.SUCCESS ) self.ota_client = OTAClient( @@ -270,7 +271,7 @@ def test_update_normal_finished(self): self.ota_lock.release.assert_called_once() assert ( self.ota_client.live_ota_status.get_ota_status() - == wrapper.StatusOta.UPDATING + == api_types.StatusOta.UPDATING ) def test_update_interrupted(self): @@ -296,9 +297,9 @@ def test_update_interrupted(self): assert ( self.ota_client.live_ota_status.get_ota_status() - == wrapper.StatusOta.FAILURE + == api_types.StatusOta.FAILURE ) - assert self.ota_client.last_failure_type == wrapper.FailureType.RECOVERABLE + assert self.ota_client.last_failure_type == api_types.FailureType.RECOVERABLE def test_rollback(self): # TODO @@ -309,49 +310,49 @@ def test_status_not_in_update(self): _status = self.ota_client.status() # assert v2 to v1 conversion - assert _status.convert_to_v1() == wrapper.StatusResponseEcu( + assert _status.convert_to_v1() == api_types.StatusResponseEcu( ecu_id=self.MY_ECU_ID, - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( version=self.CURRENT_FIRMWARE_VERSION, - status=wrapper.StatusOta.SUCCESS, + status=api_types.StatusOta.SUCCESS, ), ) # assert to v2 - assert _status == wrapper.StatusResponseEcuV2( + assert _status == api_types.StatusResponseEcuV2( ecu_id=self.MY_ECU_ID, otaclient_version=self.OTACLIENT_VERSION, firmware_version=self.CURRENT_FIRMWARE_VERSION, - failure_type=wrapper.FailureType.NO_FAILURE, - ota_status=wrapper.StatusOta.SUCCESS, + failure_type=api_types.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.SUCCESS, ) def test_status_in_update(self): # --- mock setup --- # # inject ota_updater and set ota_status to UPDATING to simulate ota updating self.ota_client._update_executor = self.ota_updater - self.ota_client.live_ota_status.set_ota_status(wrapper.StatusOta.UPDATING) + self.ota_client.live_ota_status.set_ota_status(api_types.StatusOta.UPDATING) # let mocked updater return mocked_status_progress self.ota_updater.get_update_status.return_value = self.MOCKED_STATUS_PROGRESS # --- assertion --- # _status = self.ota_client.status() # test v2 to v1 conversion - assert _status.convert_to_v1() == wrapper.StatusResponseEcu( + assert _status.convert_to_v1() == api_types.StatusResponseEcu( ecu_id=self.MY_ECU_ID, - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( version=self.CURRENT_FIRMWARE_VERSION, - status=wrapper.StatusOta.UPDATING, + status=api_types.StatusOta.UPDATING, progress=self.MOCKED_STATUS_PROGRESS_V1, ), ) # assert to v2 - assert _status == wrapper.StatusResponseEcuV2( + assert _status == api_types.StatusResponseEcuV2( ecu_id=self.MY_ECU_ID, otaclient_version=self.OTACLIENT_VERSION, - failure_type=wrapper.FailureType.NO_FAILURE, - ota_status=wrapper.StatusOta.UPDATING, + failure_type=api_types.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.UPDATING, firmware_version=self.CURRENT_FIRMWARE_VERSION, update_status=self.MOCKED_STATUS_PROGRESS, ) @@ -423,7 +424,7 @@ def test_stub_initializing(self): assert self.otaclient_stub.local_used_proxy_url is self.local_use_proxy async def test_dispatch_update(self): - update_request_ecu = wrapper.UpdateRequestEcu( + update_request_ecu = api_types.UpdateRequestEcu( ecu_id=self.ECU_INFO.ecu_id, version="version", url="url", @@ -442,11 +443,11 @@ def _updating(*args, **kwargs): await self.otaclient_stub.dispatch_update(update_request_ecu) await asyncio.sleep(0.1) # wait for inner async closure to run - assert self.otaclient_stub.last_operation is wrapper.StatusOta.UPDATING + assert self.otaclient_stub.last_operation is api_types.StatusOta.UPDATING assert self.otaclient_stub.is_busy # test ota update/rollback exclusive lock, resp = await self.otaclient_stub.dispatch_update(update_request_ecu) - assert resp.result == wrapper.FailureType.RECOVERABLE + assert resp.result == api_types.FailureType.RECOVERABLE # finish up update _updating_event.set() @@ -469,14 +470,16 @@ def _rollbacking(*args, **kwargs): self.otaclient.rollback.side_effect = _rollbacking # dispatch rollback - await self.otaclient_stub.dispatch_rollback(wrapper.RollbackRequestEcu()) + await self.otaclient_stub.dispatch_rollback(api_types.RollbackRequestEcu()) await asyncio.sleep(0.1) # wait for inner async closure to run - assert self.otaclient_stub.last_operation is wrapper.StatusOta.ROLLBACKING + assert self.otaclient_stub.last_operation is api_types.StatusOta.ROLLBACKING assert self.otaclient_stub.is_busy # test ota update/rollback exclusive lock, - resp = await self.otaclient_stub.dispatch_rollback(wrapper.RollbackRequestEcu()) - assert resp.result == wrapper.FailureType.RECOVERABLE + resp = await self.otaclient_stub.dispatch_rollback( + api_types.RollbackRequestEcu() + ) + assert resp.result == api_types.FailureType.RECOVERABLE # finish up rollback _rollbacking_event.set() diff --git a/tests/test_ota_client_service.py b/tests/test_otaclient/test_ota_client_service.py similarity index 72% rename from tests/test_ota_client_service.py rename to tests/test_otaclient/test_ota_client_service.py index 6b7c7776a..ee1ca15cc 100644 --- a/tests/test_ota_client_service.py +++ b/tests/test_otaclient/test_ota_client_service.py @@ -21,34 +21,36 @@ import pytest_mock from otaclient.app.configs import server_cfg -from otaclient.app.ota_client_call import OtaClientCall -from otaclient.app.ota_client_service import create_otaclient_grpc_server -from otaclient.app.proto import wrapper +from otaclient.app.main import create_otaclient_grpc_server from otaclient.configs.ecu_info import ECUInfo +from otaclient_api.v2 import types as api_types +from otaclient_api.v2.api_caller import OTAClientCall from tests.conftest import cfg from tests.utils import compare_message +OTACLIENT_APP_MAIN = "otaclient.app.main" + class _MockedOTAClientServiceStub: MY_ECU_ID = "autoware" - UPDATE_RESP_ECU = wrapper.UpdateResponseEcu( + UPDATE_RESP_ECU = api_types.UpdateResponseEcu( ecu_id=MY_ECU_ID, - result=wrapper.FailureType.NO_FAILURE, + result=api_types.FailureType.NO_FAILURE, ) - UPDATE_RESP = wrapper.UpdateResponse(ecu=[UPDATE_RESP_ECU]) - ROLLBACK_RESP_ECU = wrapper.RollbackResponseEcu( + UPDATE_RESP = api_types.UpdateResponse(ecu=[UPDATE_RESP_ECU]) + ROLLBACK_RESP_ECU = api_types.RollbackResponseEcu( ecu_id=MY_ECU_ID, - result=wrapper.FailureType.NO_FAILURE, + result=api_types.FailureType.NO_FAILURE, ) - ROLLBACK_RESP = wrapper.RollbackResponse(ecu=[ROLLBACK_RESP_ECU]) - STATUS_RESP_ECU = wrapper.StatusResponseEcuV2( + ROLLBACK_RESP = api_types.RollbackResponse(ecu=[ROLLBACK_RESP_ECU]) + STATUS_RESP_ECU = api_types.StatusResponseEcuV2( ecu_id=MY_ECU_ID, otaclient_version="mocked_otaclient", firmware_version="firmware", - ota_status=wrapper.StatusOta.SUCCESS, - failure_type=wrapper.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.SUCCESS, + failure_type=api_types.FailureType.NO_FAILURE, ) - STATUS_RESP = wrapper.StatusResponse( + STATUS_RESP = api_types.StatusResponse( available_ecu_ids=[MY_ECU_ID], ecu_v2=[STATUS_RESP_ECU] ) @@ -71,7 +73,7 @@ class Test_ota_client_service: def setup_test(self, mocker: pytest_mock.MockerFixture): self.otaclient_service_stub = _MockedOTAClientServiceStub() mocker.patch( - f"{cfg.OTACLIENT_SERVICE_MODULE_PATH}.OTAClientServiceStub", + f"{OTACLIENT_APP_MAIN}.OTAClientServiceStub", return_value=self.otaclient_service_stub, ) @@ -79,7 +81,7 @@ def setup_test(self, mocker: pytest_mock.MockerFixture): ecu_id=self.otaclient_service_stub.MY_ECU_ID, ip_addr=self.LISTEN_ADDR, # type: ignore ) - mocker.patch(f"{cfg.OTACLIENT_SERVICE_MODULE_PATH}.ecu_info", ecu_info_mock) + mocker.patch(f"{OTACLIENT_APP_MAIN}.ecu_info", ecu_info_mock) @pytest.fixture(autouse=True) async def launch_otaclient_server(self, setup_test): @@ -93,28 +95,28 @@ async def launch_otaclient_server(self, setup_test): async def test_otaclient_service(self): # --- test update call --- # - update_resp = await OtaClientCall.update_call( + update_resp = await OTAClientCall.update_call( ecu_id=self.MY_ECU_ID, ecu_ipaddr=self.LISTEN_ADDR, ecu_port=self.LISTEN_PORT, - request=wrapper.UpdateRequest(), + request=api_types.UpdateRequest(), ) compare_message(update_resp, self.otaclient_service_stub.UPDATE_RESP) # --- test rollback call --- # - rollback_resp = await OtaClientCall.rollback_call( + rollback_resp = await OTAClientCall.rollback_call( ecu_id=self.MY_ECU_ID, ecu_ipaddr=self.LISTEN_ADDR, ecu_port=self.LISTEN_PORT, - request=wrapper.RollbackRequest(), + request=api_types.RollbackRequest(), ) compare_message(rollback_resp, self.otaclient_service_stub.ROLLBACK_RESP) # --- test status call --- # - status_resp = await OtaClientCall.status_call( + status_resp = await OTAClientCall.status_call( ecu_id=self.MY_ECU_ID, ecu_ipaddr=self.LISTEN_ADDR, ecu_port=self.LISTEN_PORT, - request=wrapper.StatusRequest(), + request=api_types.StatusRequest(), ) compare_message(status_resp, self.otaclient_service_stub.STATUS_RESP) diff --git a/tests/test_ota_client_stub.py b/tests/test_otaclient/test_ota_client_stub.py similarity index 73% rename from tests/test_ota_client_stub.py rename to tests/test_otaclient/test_ota_client_stub.py index 77734df46..3c57cc7bb 100644 --- a/tests/test_ota_client_stub.py +++ b/tests/test_otaclient/test_ota_client_stub.py @@ -27,15 +27,15 @@ from ota_proxy import OTAProxyContextProto from ota_proxy.config import Config as otaproxyConfig from otaclient.app.ota_client import OTAServicer -from otaclient.app.ota_client_call import OtaClientCall from otaclient.app.ota_client_stub import ( ECUStatusStorage, OTAClientServiceStub, OTAProxyLauncher, ) -from otaclient.app.proto import wrapper from otaclient.configs.ecu_info import ECUInfo, parse_ecu_info from otaclient.configs.proxy_info import ProxyInfo, parse_proxy_info +from otaclient_api.v2 import types as api_types +from otaclient_api.v2.api_caller import OTAClientCall from tests.conftest import cfg from tests.utils import compare_message @@ -184,33 +184,33 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): # case 1 ( # local ECU's status report - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.SUCCESS, + ota_status=api_types.StatusOta.SUCCESS, firmware_version="123.x", - failure_type=wrapper.FailureType.NO_FAILURE, + failure_type=api_types.FailureType.NO_FAILURE, ), # sub ECU's status report [ - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.SUCCESS, + ota_status=api_types.StatusOta.SUCCESS, firmware_version="123.x", - failure_type=wrapper.FailureType.NO_FAILURE, + failure_type=api_types.FailureType.NO_FAILURE, ) ], ), - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, version="123.x", ), ), @@ -218,46 +218,46 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): ), ], # expected export - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["autoware", "p1", "p2"], # explicitly v1 format compatibility ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="autoware", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, version="123.x", ), ), - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p1", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, version="123.x", ), ), - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, version="123.x", ), ), ], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.SUCCESS, - failure_type=wrapper.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.SUCCESS, + failure_type=api_types.FailureType.NO_FAILURE, firmware_version="123.x", ), - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.SUCCESS, - failure_type=wrapper.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.SUCCESS, + failure_type=api_types.FailureType.NO_FAILURE, firmware_version="123.x", ), ], @@ -266,15 +266,15 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): # case 2 ( # local ecu status report - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.UPDATING, + ota_status=api_types.StatusOta.UPDATING, firmware_version="123.x", - failure_type=wrapper.FailureType.NO_FAILURE, - update_status=wrapper.UpdateStatus( + failure_type=api_types.FailureType.NO_FAILURE, + update_status=api_types.UpdateStatus( update_firmware_version="789.x", - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES, - total_elapsed_time=wrapper.Duration(seconds=123), + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES, + total_elapsed_time=api_types.Duration(seconds=123), total_files_num=123456, processed_files_num=123, processed_files_size=456, @@ -285,18 +285,18 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): ), # sub ECUs' status report [ - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.UPDATING, + ota_status=api_types.StatusOta.UPDATING, firmware_version="123.x", - failure_type=wrapper.FailureType.NO_FAILURE, - update_status=wrapper.UpdateStatus( + failure_type=api_types.FailureType.NO_FAILURE, + update_status=api_types.UpdateStatus( update_firmware_version="789.x", - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES, - total_elapsed_time=wrapper.Duration(seconds=123), + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES, + total_elapsed_time=api_types.Duration(seconds=123), total_files_num=123456, processed_files_num=123, processed_files_size=456, @@ -307,14 +307,14 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): ) ], ), - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, version="123.x", ), ), @@ -322,20 +322,20 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): ), ], # expected export result - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["autoware", "p1", "p2"], # explicitly v1 format compatibility # NOTE: processed_files_num(v2) = files_processed_download(v1) + files_processed_copy(v1) - # check wrapper.UpdateStatus.convert_to_v1_StatusProgress for more details. + # check api_types.UpdateStatus.convert_to_v1_StatusProgress for more details. ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="autoware", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, version="123.x", - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, total_regular_files=123456, files_processed_download=100, file_size_processed_download=400, @@ -343,18 +343,18 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): file_size_processed_copy=56, download_bytes=789, regular_files_processed=123, - total_elapsed_time=wrapper.Duration(seconds=123), + total_elapsed_time=api_types.Duration(seconds=123), ), ), ), - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p1", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, version="123.x", - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, total_regular_files=123456, files_processed_download=100, file_size_processed_download=400, @@ -362,29 +362,29 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): file_size_processed_copy=56, download_bytes=789, regular_files_processed=123, - total_elapsed_time=wrapper.Duration(seconds=123), + total_elapsed_time=api_types.Duration(seconds=123), ), ), ), - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - result=wrapper.FailureType.NO_FAILURE, - status=wrapper.Status( + result=api_types.FailureType.NO_FAILURE, + status=api_types.Status( version="123.x", - status=wrapper.StatusOta.SUCCESS, + status=api_types.StatusOta.SUCCESS, ), ), ], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.UPDATING, - failure_type=wrapper.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.UPDATING, + failure_type=api_types.FailureType.NO_FAILURE, firmware_version="123.x", - update_status=wrapper.UpdateStatus( + update_status=api_types.UpdateStatus( update_firmware_version="789.x", - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES, - total_elapsed_time=wrapper.Duration(seconds=123), + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES, + total_elapsed_time=api_types.Duration(seconds=123), total_files_num=123456, processed_files_num=123, processed_files_size=456, @@ -393,15 +393,15 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): downloaded_files_size=400, ), ), - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.UPDATING, - failure_type=wrapper.FailureType.NO_FAILURE, + ota_status=api_types.StatusOta.UPDATING, + failure_type=api_types.FailureType.NO_FAILURE, firmware_version="123.x", - update_status=wrapper.UpdateStatus( + update_status=api_types.UpdateStatus( update_firmware_version="789.x", - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES, - total_elapsed_time=wrapper.Duration(seconds=123), + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES, + total_elapsed_time=api_types.Duration(seconds=123), total_files_num=123456, processed_files_num=123, processed_files_size=456, @@ -417,9 +417,9 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture): ) async def test_export( self, - local_ecu_status: wrapper.StatusResponseEcuV2, - sub_ecus_status: List[wrapper.StatusResponse], - expected: wrapper.StatusResponse, + local_ecu_status: api_types.StatusResponseEcuV2, + sub_ecus_status: List[api_types.StatusResponse], + expected: api_types.StatusResponse, ): # --- prepare --- # await self.ecu_storage.update_from_local_ecu(local_ecu_status) @@ -438,34 +438,34 @@ async def test_export( # case 1: ( # local ECU status: UPDATING, requires network - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.UPDATING, - update_status=wrapper.UpdateStatus( - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES + ota_status=api_types.StatusOta.UPDATING, + update_status=api_types.UpdateStatus( + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES ), ), # sub ECUs status [ - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.FAILURE, + ota_status=api_types.StatusOta.FAILURE, ), ], ), # p2: updating, doesn't require network - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.POST_PROCESSING, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.POST_PROCESSING, ), ), ) @@ -486,32 +486,32 @@ async def test_export( # case 2: ( # local ECU status: SUCCESS - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.SUCCESS, + ota_status=api_types.StatusOta.SUCCESS, ), # sub ECUs status [ # p1: FAILURE - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.FAILURE, + ota_status=api_types.StatusOta.FAILURE, ), ], ), # p2: updating, requires network - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, ), ), ) @@ -533,8 +533,8 @@ async def test_export( ) async def test_overall_ecu_status_report_generation( self, - local_ecu_status: wrapper.StatusResponseEcuV2, - sub_ecus_status: List[wrapper.StatusResponse], + local_ecu_status: api_types.StatusResponseEcuV2, + sub_ecus_status: List[api_types.StatusResponse], properties_dict: Dict[str, Any], ): # --- prepare --- # @@ -559,32 +559,32 @@ async def test_overall_ecu_status_report_generation( # based on the status change of ECUs that accept update request. ( # local ECU status: FAILED - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.FAILURE, + ota_status=api_types.StatusOta.FAILURE, ), # sub ECUs status [ # p1: FAILED - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.FAILURE, + ota_status=api_types.StatusOta.FAILURE, ), ], ), # p2: UPDATING - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, ), ), ) @@ -610,33 +610,33 @@ async def test_overall_ecu_status_report_generation( # based on the status change of ECUs that accept update request. ( # local ECU status: UPDATING - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="autoware", - ota_status=wrapper.StatusOta.UPDATING, - update_status=wrapper.UpdateStatus( - phase=wrapper.UpdatePhase.DOWNLOADING_OTA_FILES, + ota_status=api_types.StatusOta.UPDATING, + update_status=api_types.UpdateStatus( + phase=api_types.UpdatePhase.DOWNLOADING_OTA_FILES, ), ), # sub ECUs status [ # p1: FAILED - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p1"], ecu_v2=[ - wrapper.StatusResponseEcuV2( + api_types.StatusResponseEcuV2( ecu_id="p1", - ota_status=wrapper.StatusOta.FAILURE, + ota_status=api_types.StatusOta.FAILURE, ), ], ), # p2: SUCCESS - wrapper.StatusResponse( + api_types.StatusResponse( available_ecu_ids=["p2"], ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="p2", - status=wrapper.Status( - status=wrapper.StatusOta.SUCCESS, + status=api_types.Status( + status=api_types.StatusOta.SUCCESS, ), ) ], @@ -658,8 +658,8 @@ async def test_overall_ecu_status_report_generation( ) async def test_on_receive_update_request( self, - local_ecu_status: wrapper.StatusResponseEcuV2, - sub_ecus_status: List[wrapper.StatusResponse], + local_ecu_status: api_types.StatusResponseEcuV2, + sub_ecus_status: List[api_types.StatusResponse], ecus_accept_update_request: List[str], properties_dict: Dict[str, Any], ): @@ -710,10 +710,10 @@ class TestOTAClientServiceStub: @staticmethod async def _subecu_accept_update_request(ecu_id, *args, **kwargs): - return wrapper.UpdateResponse( + return api_types.UpdateResponse( ecu=[ - wrapper.UpdateResponseEcu( - ecu_id=ecu_id, result=wrapper.FailureType.NO_FAILURE + api_types.UpdateResponseEcu( + ecu_id=ecu_id, result=api_types.FailureType.NO_FAILURE ) ] ) @@ -741,11 +741,11 @@ async def setup_test( await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) # ensure the task stopping # --- mocker --- # - self.otaclient_wrapper = mocker.MagicMock(spec=OTAServicer) + self.otaclient_api_types = mocker.MagicMock(spec=OTAServicer) self.ecu_status_tracker = mocker.MagicMock() self.otaproxy_launcher = mocker.MagicMock(spec=OTAProxyLauncher) # mock OTAClientCall, make update_call return success on any update dispatches to subECUs - self.otaclient_call = mocker.AsyncMock(spec=OtaClientCall) + self.otaclient_call = mocker.AsyncMock(spec=OTAClientCall) self.otaclient_call.update_call = mocker.AsyncMock( wraps=self._subecu_accept_update_request ) @@ -761,7 +761,7 @@ async def setup_test( ) mocker.patch( f"{cfg.OTACLIENT_STUB_MODULE_PATH}.OTAServicer", - mocker.MagicMock(return_value=self.otaclient_wrapper), + mocker.MagicMock(return_value=self.otaclient_api_types), ) mocker.patch( f"{cfg.OTACLIENT_STUB_MODULE_PATH}._ECUTracker", @@ -772,7 +772,7 @@ async def setup_test( mocker.MagicMock(return_value=self.otaproxy_launcher), ) mocker.patch( - f"{cfg.OTACLIENT_STUB_MODULE_PATH}.OtaClientCall", self.otaclient_call + f"{cfg.OTACLIENT_STUB_MODULE_PATH}.OTAClientCall", self.otaclient_call ) # --- start the OTAClientServiceStub --- # @@ -845,15 +845,15 @@ async def test__otaclient_control_flags_managing(self): ( # update request for autoware, p1 ecus ( - wrapper.UpdateRequest( + api_types.UpdateRequest( ecu=[ - wrapper.UpdateRequestEcu( + api_types.UpdateRequestEcu( ecu_id="autoware", version="789.x", url="url", cookies="cookies", ), - wrapper.UpdateRequestEcu( + api_types.UpdateRequestEcu( ecu_id="p1", version="789.x", url="url", @@ -865,24 +865,24 @@ async def test__otaclient_control_flags_managing(self): # NOTE: order matters! # update request dispatching to subECUs happens first, # and then to the local ECU. - wrapper.UpdateResponse( + api_types.UpdateResponse( ecu=[ - wrapper.UpdateResponseEcu( + api_types.UpdateResponseEcu( ecu_id="p1", - result=wrapper.FailureType.NO_FAILURE, + result=api_types.FailureType.NO_FAILURE, ), - wrapper.UpdateResponseEcu( + api_types.UpdateResponseEcu( ecu_id="autoware", - result=wrapper.FailureType.NO_FAILURE, + result=api_types.FailureType.NO_FAILURE, ), ] ), ), # update only p2 ( - wrapper.UpdateRequest( + api_types.UpdateRequest( ecu=[ - wrapper.UpdateRequestEcu( + api_types.UpdateRequestEcu( ecu_id="p2", version="789.x", url="url", @@ -891,11 +891,11 @@ async def test__otaclient_control_flags_managing(self): ] ), {"p2"}, - wrapper.UpdateResponse( + api_types.UpdateResponse( ecu=[ - wrapper.UpdateResponseEcu( + api_types.UpdateResponseEcu( ecu_id="p2", - result=wrapper.FailureType.NO_FAILURE, + result=api_types.FailureType.NO_FAILURE, ), ] ), @@ -904,13 +904,15 @@ async def test__otaclient_control_flags_managing(self): ) async def test_update_normal( self, - update_request: wrapper.UpdateRequest, + update_request: api_types.UpdateRequest, update_target_ids: Set[str], - expected: wrapper.UpdateResponse, + expected: api_types.UpdateResponse, ): # --- setup --- # - self.otaclient_wrapper.dispatch_update.return_value = wrapper.UpdateResponseEcu( - ecu_id=self.ecu_info.ecu_id, result=wrapper.FailureType.NO_FAILURE + self.otaclient_api_types.dispatch_update.return_value = ( + api_types.UpdateResponseEcu( + ecu_id=self.ecu_info.ecu_id, result=api_types.FailureType.NO_FAILURE + ) ) # --- execution --- # @@ -925,27 +927,29 @@ async def test_update_normal( async def test_update_local_ecu_busy(self): # --- preparation --- # - self.otaclient_wrapper.dispatch_update.return_value = wrapper.UpdateResponseEcu( - ecu_id="autoware", result=wrapper.FailureType.RECOVERABLE + self.otaclient_api_types.dispatch_update.return_value = ( + api_types.UpdateResponseEcu( + ecu_id="autoware", result=api_types.FailureType.RECOVERABLE + ) ) - update_request_ecu = wrapper.UpdateRequestEcu( + update_request_ecu = api_types.UpdateRequestEcu( ecu_id="autoware", version="version", url="url", cookies="cookies" ) # --- execution --- # resp = await self.otaclient_service_stub.update( - wrapper.UpdateRequest(ecu=[update_request_ecu]) + api_types.UpdateRequest(ecu=[update_request_ecu]) ) # --- assertion --- # - assert resp == wrapper.UpdateResponse( + assert resp == api_types.UpdateResponse( ecu=[ - wrapper.UpdateResponseEcu( + api_types.UpdateResponseEcu( ecu_id="autoware", - result=wrapper.FailureType.RECOVERABLE, + result=api_types.FailureType.RECOVERABLE, ) ] ) - self.otaclient_wrapper.dispatch_update.assert_called_once_with( + self.otaclient_api_types.dispatch_update.assert_called_once_with( update_request_ecu ) diff --git a/tests/test_update_stats.py b/tests/test_otaclient/test_update_stats.py similarity index 100% rename from tests/test_update_stats.py rename to tests/test_otaclient/test_update_stats.py diff --git a/tests/test_otaclient_api/__init__.py b/tests/test_otaclient_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_otaclient_api/test_v2/__init__.py b/tests/test_otaclient_api/test_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_ota_client_call.py b/tests/test_otaclient_api/test_v2/test_apli_caller.py similarity index 90% rename from tests/test_ota_client_call.py rename to tests/test_otaclient_api/test_v2/test_apli_caller.py index f883851e8..96298acfe 100644 --- a/tests/test_ota_client_call.py +++ b/tests/test_otaclient_api/test_v2/test_apli_caller.py @@ -17,8 +17,10 @@ import pytest import pytest_asyncio -from otaclient.app.ota_client_call import ECUNoResponse, OtaClientCall -from otaclient.app.proto import v2, v2_grpc, wrapper +from otaclient_api.v2 import otaclient_v2_pb2 as v2 +from otaclient_api.v2 import otaclient_v2_pb2_grpc as v2_grpc +from otaclient_api.v2 import types as api_types +from otaclient_api.v2.api_caller import ECUNoResponse, OTAClientCall from tests.utils import compare_message @@ -133,10 +135,10 @@ async def dummy_ota_client_service(self): await server.stop(None) async def test_update_call(self, dummy_ota_client_service): - _req = wrapper.UpdateRequest.convert( + _req = api_types.UpdateRequest.convert( _DummyOTAClientService.DUMMY_UPDATE_REQUEST ) - _response = await OtaClientCall.update_call( + _response = await OTAClientCall.update_call( ecu_id=self.DUMMY_ECU_ID, ecu_ipaddr=self.OTA_CLIENT_SERVICE_IP, ecu_port=self.OTA_CLIENT_SERVICE_PORT, @@ -147,10 +149,10 @@ async def test_update_call(self, dummy_ota_client_service): ) async def test_rollback_call(self, dummy_ota_client_service): - _req = wrapper.RollbackRequest.convert( + _req = api_types.RollbackRequest.convert( _DummyOTAClientService.DUMMY_ROLLBACK_REQUEST ) - _response = await OtaClientCall.rollback_call( + _response = await OTAClientCall.rollback_call( ecu_id=self.DUMMY_ECU_ID, ecu_ipaddr=self.OTA_CLIENT_SERVICE_IP, ecu_port=self.OTA_CLIENT_SERVICE_PORT, @@ -161,22 +163,22 @@ async def test_rollback_call(self, dummy_ota_client_service): ) async def test_status_call(self, dummy_ota_client_service): - _response = await OtaClientCall.status_call( + _response = await OTAClientCall.status_call( ecu_id=self.DUMMY_ECU_ID, ecu_ipaddr=self.OTA_CLIENT_SERVICE_IP, ecu_port=self.OTA_CLIENT_SERVICE_PORT, - request=wrapper.StatusRequest(), + request=api_types.StatusRequest(), ) assert _response is not None compare_message(_response.export_pb(), _DummyOTAClientService.DUMMY_STATUS) async def test_update_call_no_response(self): - _req = wrapper.UpdateRequest.convert( + _req = api_types.UpdateRequest.convert( _DummyOTAClientService.DUMMY_UPDATE_REQUEST ) with pytest.raises(ECUNoResponse): - await OtaClientCall.update_call( + await OTAClientCall.update_call( ecu_id=self.DUMMY_ECU_ID, ecu_ipaddr=self.OTA_CLIENT_SERVICE_IP, ecu_port=self.OTA_CLIENT_SERVICE_PORT, diff --git a/tests/test_proto/test_otaclient_pb2_wrapper.py b/tests/test_otaclient_api/test_v2/test_types.py similarity index 70% rename from tests/test_proto/test_otaclient_pb2_wrapper.py rename to tests/test_otaclient_api/test_v2/test_types.py index 87c63fd0f..b3d74f8dc 100644 --- a/tests/test_proto/test_otaclient_pb2_wrapper.py +++ b/tests/test_otaclient_api/test_v2/test_types.py @@ -18,7 +18,8 @@ import pytest from google.protobuf.duration_pb2 import Duration as _Duration -from otaclient.app.proto import v2, wrapper +from otaclient_api.v2 import otaclient_v2_pb2 as v2 +from otaclient_api.v2 import types as api_types from tests.utils import compare_message @@ -32,12 +33,14 @@ total_regular_files=123456, elapsed_time_download=_Duration(seconds=1, nanos=5678), ), - wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, total_regular_files=123456, - elapsed_time_download=wrapper.Duration.from_nanoseconds(1_000_005_678), + elapsed_time_download=api_types.Duration.from_nanoseconds( + 1_000_005_678 + ), ), - wrapper.StatusProgress, + api_types.StatusProgress, ), # UpdateRequest: with protobuf repeated composite field ( @@ -47,13 +50,13 @@ v2.UpdateRequestEcu(ecu_id="ecu_2"), ] ), - wrapper.UpdateRequest( + api_types.UpdateRequest( ecu=[ - wrapper.UpdateRequestEcu(ecu_id="ecu_1"), - wrapper.UpdateRequestEcu(ecu_id="ecu_2"), + api_types.UpdateRequestEcu(ecu_id="ecu_1"), + api_types.UpdateRequestEcu(ecu_id="ecu_2"), ] ), - wrapper.UpdateRequest, + api_types.UpdateRequest, ), # UpdateRequest: with protobuf repeated composite field, ( @@ -63,13 +66,13 @@ v2.UpdateRequestEcu(ecu_id="ecu_2"), ] ), - wrapper.UpdateRequest( + api_types.UpdateRequest( ecu=[ - wrapper.UpdateRequestEcu(ecu_id="ecu_1"), - wrapper.UpdateRequestEcu(ecu_id="ecu_2"), + api_types.UpdateRequestEcu(ecu_id="ecu_1"), + api_types.UpdateRequestEcu(ecu_id="ecu_2"), ] ), - wrapper.UpdateRequest, + api_types.UpdateRequest, ), # StatusResponse: multiple layer nested message, multiple protobuf message types ( @@ -100,29 +103,29 @@ ], available_ecu_ids=["ecu_1", "ecu_2"], ), - wrapper.StatusResponse( + api_types.StatusResponse( ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="ecu_1", - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, total_regular_files=123456, - elapsed_time_copy=wrapper.Duration.from_nanoseconds( + elapsed_time_copy=api_types.Duration.from_nanoseconds( 1_000_056_789 ), ), ), ), - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id="ecu_2", - status=wrapper.Status( - status=wrapper.StatusOta.UPDATING, - progress=wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR, + status=api_types.Status( + status=api_types.StatusOta.UPDATING, + progress=api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR, total_regular_files=456789, - elapsed_time_copy=wrapper.Duration.from_nanoseconds( + elapsed_time_copy=api_types.Duration.from_nanoseconds( 1_000_012_345 ), ), @@ -131,14 +134,14 @@ ], available_ecu_ids=["ecu_1", "ecu_2"], ), - wrapper.StatusResponse, + api_types.StatusResponse, ), ), ) def test_convert_message( origin_msg, - converted_msg: wrapper.MessageWrapper, - wrapper_type: Type[wrapper.MessageWrapper], + converted_msg: api_types.MessageWrapper, + wrapper_type: Type[api_types.MessageWrapper], ): # ------ converting message ------ # _converted = wrapper_type.convert(origin_msg) @@ -155,13 +158,13 @@ class Test_enum_wrapper_cooperate: def test_direct_compare(self): """protobuf enum and wrapper enum can compare directly.""" _protobuf_enum = v2.UPDATING - _wrapped = wrapper.StatusOta.UPDATING + _wrapped = api_types.StatusOta.UPDATING assert _protobuf_enum == _wrapped def test_assign_to_protobuf_message(self): """wrapper enum can be directly assigned in protobuf message.""" l, r = v2.StatusProgress(phase=v2.REGULAR), v2.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR.value, + phase=api_types.StatusProgressPhase.REGULAR.value, # type: ignore ) compare_message(l, r) @@ -169,8 +172,8 @@ def test_used_in_message_wrapper(self): """wrapper enum can be exported.""" l, r = ( v2.StatusProgress(phase=v2.REGULAR), - wrapper.StatusProgress( - phase=wrapper.StatusProgressPhase.REGULAR + api_types.StatusProgress( + phase=api_types.StatusProgressPhase.REGULAR ).export_pb(), ) compare_message(l, r) @@ -178,6 +181,6 @@ def test_used_in_message_wrapper(self): def test_converted_from_protobuf_enum(self): """wrapper enum can be converted from and to protobuf enum.""" _protobuf_enum = v2.REGULAR - _converted = wrapper.StatusProgressPhase(_protobuf_enum) + _converted = api_types.StatusProgressPhase(_protobuf_enum) assert _protobuf_enum == _converted - assert _converted == wrapper.StatusProgressPhase.REGULAR + assert _converted == api_types.StatusProgressPhase.REGULAR diff --git a/tests/test_otaclient_common/__init__.py b/tests/test_otaclient_common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_common.py b/tests/test_otaclient_common/test_common.py similarity index 72% rename from tests/test_common.py rename to tests/test_otaclient_common/test_common.py index d745faf76..28589fe22 100644 --- a/tests/test_common.py +++ b/tests/test_otaclient_common/test_common.py @@ -13,12 +13,12 @@ # limitations under the License. +from __future__ import annotations + import logging import os -import random import subprocess import time -from functools import partial from hashlib import sha256 from multiprocessing import Process from pathlib import Path @@ -26,13 +26,10 @@ import pytest -from otaclient.app.common import ( - RetryTaskMap, - RetryTaskMapInterrupted, +from otaclient_common.common import ( copytree_identical, ensure_otaproxy_start, file_sha256, - get_backoff, re_symlink_atomic, read_str_from_file, subprocess_call, @@ -231,122 +228,6 @@ def test_symlink_to_directory(self, tmp_path: Path): assert _symlink.is_symlink() and os.readlink(_symlink) == str(_target) -class _RetryTaskMapTestErr(Exception): - """""" - - -class TestRetryTaskMap: - WAIT_CONST = 100_000_000 - TASKS_COUNT = 2000 - MAX_CONCURRENT = 128 - DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT = 6 # seconds - MAX_WAIT_BEFORE_SUCCESS = 10 - - @pytest.fixture(autouse=True) - def setup(self): - self._start_time = time.time() - self._success_wait_dict = { - idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS) - for idx in range(self.TASKS_COUNT) - } - self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)] - - def workload_aways_failed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) - raise _RetryTaskMapTestErr - - def workload_failed_and_then_succeed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) - if time.time() > self._start_time + self._success_wait_dict[idx]: - self._succeeded_tasks[idx] = True - return idx - raise _RetryTaskMapTestErr - - def workload_succeed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) - self._succeeded_tasks[idx] = True - return idx - - def test_retry_keep_failing_timeout(self): - _keep_failing_timer = time.time() - with pytest.raises(RetryTaskMapInterrupted): - _mapper = RetryTaskMap( - backoff_func=partial(get_backoff, factor=0.1, _max=1), - max_concurrent=self.MAX_CONCURRENT, - max_retry=0, # we are testing keep failing timeout here - ) - for done_task in _mapper.map( - self.workload_aways_failed, range(self.TASKS_COUNT) - ): - if not done_task.fut.exception(): - # reset the failing timer on one succeeded task - _keep_failing_timer = time.time() - continue - if ( - time.time() - _keep_failing_timer - > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT - ): - logger.error( - f"RetryTaskMap successfully failed after keep failing in {self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT}s" - ) - _mapper.shutdown(raise_last_exc=True) - - def test_retry_exceed_retry_limit(self): - with pytest.raises(RetryTaskMapInterrupted): - _mapper = RetryTaskMap( - backoff_func=partial(get_backoff, factor=0.1, _max=1), - max_concurrent=self.MAX_CONCURRENT, - max_retry=3, - ) - for _ in _mapper.map(self.workload_aways_failed, range(self.TASKS_COUNT)): - pass - - def test_retry_finally_succeeded(self): - _keep_failing_timer = time.time() - - _mapper = RetryTaskMap( - backoff_func=partial(get_backoff, factor=0.1, _max=1), - max_concurrent=self.MAX_CONCURRENT, - max_retry=0, # we are testing keep failing timeout here - ) - for done_task in _mapper.map( - self.workload_failed_and_then_succeed, range(self.TASKS_COUNT) - ): - # task successfully finished - if not done_task.fut.exception(): - # reset the failing timer on one succeeded task - _keep_failing_timer = time.time() - continue - - if ( - time.time() - _keep_failing_timer - > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT - ): - _mapper.shutdown(raise_last_exc=True) - assert all(self._succeeded_tasks) - - def test_succeeded_in_one_try(self): - _keep_failing_timer = time.time() - _mapper = RetryTaskMap( - backoff_func=partial(get_backoff, factor=0.1, _max=1), - max_concurrent=self.MAX_CONCURRENT, - max_retry=0, # we are testing keep failing timeout here - ) - for done_task in _mapper.map(self.workload_succeed, range(self.TASKS_COUNT)): - # task successfully finished - if not done_task.fut.exception(): - # reset the failing timer on one succeeded task - _keep_failing_timer = time.time() - continue - - if ( - time.time() - _keep_failing_timer - > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT - ): - _mapper.shutdown(raise_last_exc=True) - assert all(self._succeeded_tasks) - - class Test_ensure_otaproxy_start: DUMMY_SERVER_ADDR, DUMMY_SERVER_PORT = "127.0.0.1", 18888 DUMMY_SERVER_URL = f"http://{DUMMY_SERVER_ADDR}:{DUMMY_SERVER_PORT}" diff --git a/tests/test_downloader.py b/tests/test_otaclient_common/test_downloader.py similarity index 98% rename from tests/test_downloader.py rename to tests/test_otaclient_common/test_downloader.py index b11324d53..4af3d23c1 100644 --- a/tests/test_downloader.py +++ b/tests/test_otaclient_common/test_downloader.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import asyncio import logging import threading @@ -24,8 +26,8 @@ import requests import requests_mock -from otaclient.app.common import file_sha256, urljoin_ensure_base -from otaclient.app.downloader import ( +from otaclient_common.common import file_sha256, urljoin_ensure_base +from otaclient_common.downloader import ( ChunkStreamingError, DestinationNotAvailableError, Downloader, diff --git a/tests/test_persist_file_handling.py b/tests/test_otaclient_common/test_persist_file_handling.py similarity index 99% rename from tests/test_persist_file_handling.py rename to tests/test_otaclient_common/test_persist_file_handling.py index cc00a0c0d..1b9823811 100644 --- a/tests/test_persist_file_handling.py +++ b/tests/test_otaclient_common/test_persist_file_handling.py @@ -19,8 +19,8 @@ import stat from pathlib import Path -from otaclient._utils import replace_root -from otaclient.app.common import PersistFilesHandler +from otaclient_common import replace_root +from otaclient_common.persist_file_handling import PersistFilesHandler def create_files(tmp_path: Path): diff --git a/tests/test_proto/__init__.py b/tests/test_otaclient_common/test_proto_wrapper/__init__.py similarity index 100% rename from tests/test_proto/__init__.py rename to tests/test_otaclient_common/test_proto_wrapper/__init__.py diff --git a/tests/test_proto/example.proto b/tests/test_otaclient_common/test_proto_wrapper/example.proto similarity index 100% rename from tests/test_proto/example.proto rename to tests/test_otaclient_common/test_proto_wrapper/example.proto diff --git a/tests/test_proto/example_pb2.py b/tests/test_otaclient_common/test_proto_wrapper/example_pb2.py similarity index 100% rename from tests/test_proto/example_pb2.py rename to tests/test_otaclient_common/test_proto_wrapper/example_pb2.py diff --git a/tests/test_proto/example_pb2.pyi b/tests/test_otaclient_common/test_proto_wrapper/example_pb2.pyi similarity index 100% rename from tests/test_proto/example_pb2.pyi rename to tests/test_otaclient_common/test_proto_wrapper/example_pb2.pyi diff --git a/tests/test_proto/example_pb2_wrapper.py b/tests/test_otaclient_common/test_proto_wrapper/example_pb2_wrapper.py similarity index 96% rename from tests/test_proto/example_pb2_wrapper.py rename to tests/test_otaclient_common/test_proto_wrapper/example_pb2_wrapper.py index 45ae7cd04..60f14e297 100644 --- a/tests/test_proto/example_pb2_wrapper.py +++ b/tests/test_otaclient_common/test_proto_wrapper/example_pb2_wrapper.py @@ -13,12 +13,14 @@ # limitations under the License. +from __future__ import annotations + from typing import Iterable as _Iterable from typing import Mapping as _Mapping from typing import Optional as _Optional from typing import Union as _Union -from otaclient.app.proto.wrapper import ( +from otaclient_common.proto_wrapper import ( Duration, EnumWrapper, MessageMapContainer, diff --git a/tests/test_proto/test_proto_wrapper.py b/tests/test_otaclient_common/test_proto_wrapper/test_proto_wrapper.py similarity index 97% rename from tests/test_proto/test_proto_wrapper.py rename to tests/test_otaclient_common/test_proto_wrapper/test_proto_wrapper.py index 269ae8245..876b81a5d 100644 --- a/tests/test_proto/test_proto_wrapper.py +++ b/tests/test_otaclient_common/test_proto_wrapper/test_proto_wrapper.py @@ -13,12 +13,14 @@ # limitations under the License. -from typing import Any, Dict +from __future__ import annotations + +from typing import Any import pytest from google.protobuf.duration_pb2 import Duration as _pb2_Duration -from otaclient.app.proto import wrapper as proto_wrapper +from otaclient_common import proto_wrapper from tests.utils import compare_message from . import example_pb2 as pb2 @@ -61,7 +63,7 @@ ), ) def test_default_value_behavior( - input_wrapper_inst: proto_wrapper.MessageWrapper, expected_dict: Dict[str, Any] + input_wrapper_inst: proto_wrapper.MessageWrapper, expected_dict: dict[str, Any] ): for _field_name in input_wrapper_inst._fields: _value = getattr(input_wrapper_inst, _field_name) diff --git a/tests/test_otaclient_common/test_retry_task_map.py b/tests/test_otaclient_common/test_retry_task_map.py new file mode 100644 index 000000000..575d2450d --- /dev/null +++ b/tests/test_otaclient_common/test_retry_task_map.py @@ -0,0 +1,144 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +import random +import time +from functools import partial + +import pytest + +from otaclient_common.common import get_backoff +from otaclient_common.retry_task_map import RetryTaskMap, RetryTaskMapInterrupted + +logger = logging.getLogger(__name__) + + +class _RetryTaskMapTestErr(Exception): + """""" + + +class TestRetryTaskMap: + WAIT_CONST = 100_000_000 + TASKS_COUNT = 2000 + MAX_CONCURRENT = 128 + DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT = 6 # seconds + MAX_WAIT_BEFORE_SUCCESS = 10 + + @pytest.fixture(autouse=True) + def setup(self): + self._start_time = time.time() + self._success_wait_dict = { + idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS) + for idx in range(self.TASKS_COUNT) + } + self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)] + + def workload_aways_failed(self, idx: int) -> int: + time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + raise _RetryTaskMapTestErr + + def workload_failed_and_then_succeed(self, idx: int) -> int: + time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + if time.time() > self._start_time + self._success_wait_dict[idx]: + self._succeeded_tasks[idx] = True + return idx + raise _RetryTaskMapTestErr + + def workload_succeed(self, idx: int) -> int: + time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + self._succeeded_tasks[idx] = True + return idx + + def test_retry_keep_failing_timeout(self): + _keep_failing_timer = time.time() + with pytest.raises(RetryTaskMapInterrupted): + _mapper = RetryTaskMap( + backoff_func=partial(get_backoff, factor=0.1, _max=1), + max_concurrent=self.MAX_CONCURRENT, + max_retry=0, # we are testing keep failing timeout here + ) + for done_task in _mapper.map( + self.workload_aways_failed, range(self.TASKS_COUNT) + ): + if not done_task.fut.exception(): + # reset the failing timer on one succeeded task + _keep_failing_timer = time.time() + continue + if ( + time.time() - _keep_failing_timer + > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT + ): + logger.error( + f"RetryTaskMap successfully failed after keep failing in {self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT}s" + ) + _mapper.shutdown(raise_last_exc=True) + + def test_retry_exceed_retry_limit(self): + with pytest.raises(RetryTaskMapInterrupted): + _mapper = RetryTaskMap( + backoff_func=partial(get_backoff, factor=0.1, _max=1), + max_concurrent=self.MAX_CONCURRENT, + max_retry=3, + ) + for _ in _mapper.map(self.workload_aways_failed, range(self.TASKS_COUNT)): + pass + + def test_retry_finally_succeeded(self): + _keep_failing_timer = time.time() + + _mapper = RetryTaskMap( + backoff_func=partial(get_backoff, factor=0.1, _max=1), + max_concurrent=self.MAX_CONCURRENT, + max_retry=0, # we are testing keep failing timeout here + ) + for done_task in _mapper.map( + self.workload_failed_and_then_succeed, range(self.TASKS_COUNT) + ): + # task successfully finished + if not done_task.fut.exception(): + # reset the failing timer on one succeeded task + _keep_failing_timer = time.time() + continue + + if ( + time.time() - _keep_failing_timer + > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT + ): + _mapper.shutdown(raise_last_exc=True) + assert all(self._succeeded_tasks) + + def test_succeeded_in_one_try(self): + _keep_failing_timer = time.time() + _mapper = RetryTaskMap( + backoff_func=partial(get_backoff, factor=0.1, _max=1), + max_concurrent=self.MAX_CONCURRENT, + max_retry=0, # we are testing keep failing timeout here + ) + for done_task in _mapper.map(self.workload_succeed, range(self.TASKS_COUNT)): + # task successfully finished + if not done_task.fut.exception(): + # reset the failing timer on one succeeded task + _keep_failing_timer = time.time() + continue + + if ( + time.time() - _keep_failing_timer + > self.DOWNLOAD_GROUP_NO_SUCCESS_RETRY_TIMEOUT + ): + _mapper.shutdown(raise_last_exc=True) + assert all(self._succeeded_tasks) diff --git a/tests/utils.py b/tests/utils.py index f3674a763..8a4e5bd2a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import asyncio import logging import os @@ -27,8 +29,9 @@ import zstandard from google.protobuf.message import Message as _Message -from otaclient.app.common import file_sha256 -from otaclient.app.proto import v2_grpc, wrapper +from otaclient_api.v2 import otaclient_v2_pb2_grpc as v2_grpc +from otaclient_api.v2 import types as api_types +from otaclient_common.common import file_sha256 logger = logging.getLogger(__name__) @@ -113,13 +116,13 @@ def compare_dir(left: Path, right: Path): class DummySubECU: - SUCCESS_RESPONSE = wrapper.Status( - status=wrapper.StatusOta.SUCCESS, - failure=wrapper.FailureType.NO_FAILURE, + SUCCESS_RESPONSE = api_types.Status( + status=api_types.StatusOta.SUCCESS, + failure=api_types.FailureType.NO_FAILURE, ) - UPDATING_RESPONSE = wrapper.Status( - status=wrapper.StatusOta.UPDATING, - failure=wrapper.FailureType.NO_FAILURE, + UPDATING_RESPONSE = api_types.Status( + status=api_types.StatusOta.UPDATING, + failure=api_types.FailureType.NO_FAILURE, ) UPDATE_TIME_COST = 6 REBOOT_TIME_COST = 1 @@ -138,9 +141,9 @@ def status(self): # update not yet started if self._receive_update_time is None: logger.debug(f"{self.ecu_id=}, update not yet started") - res = wrapper.StatusResponse( + res = api_types.StatusResponse( ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id=self.ecu_id, status=self.SUCCESS_RESPONSE, ) @@ -155,9 +158,9 @@ def status(self): logger.debug( f"update finished for {self.ecu_id=}, {self._receive_update_time=}, {time.time()=}" ) - res = wrapper.StatusResponse( + res = api_types.StatusResponse( ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id=self.ecu_id, status=self.SUCCESS_RESPONSE, ) @@ -172,9 +175,9 @@ def status(self): return None # updating logger.debug(f"{self.ecu_id=}, updating") - res = wrapper.StatusResponse( + res = api_types.StatusResponse( ecu=[ - wrapper.StatusResponseEcu( + api_types.StatusResponseEcu( ecu_id=self.ecu_id, status=self.UPDATING_RESPONSE, ) diff --git a/tools/build_image.sh b/tools/build_image.sh deleted file mode 100644 index 9cad89ec4..000000000 --- a/tools/build_image.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -set -eux - -docker build -t ota-image -f ./docker/build_image/Dockerfile . -id=$(docker create -it ota-image) -ota_image_dir="ota-image.$(date +%Y%m%d%H%M%S)" -mkdir ${ota_image_dir} - -cd ${ota_image_dir} -docker export ${id} > ota-image.tar -docker rm ${id} -mkdir data -sudo tar xf ota-image.tar -C data -git clone https://github.com/tier4/ota-metadata - -cp ../tests/keys/sign.pem . -cp ota-metadata/metadata/persistents.txt . - -sudo python3 ota-metadata/metadata/ota_metadata/metadata_gen.py --target-dir data --ignore-file ota-metadata/metadata/ignore.txt -sudo python3 ota-metadata/metadata/ota_metadata/metadata_sign.py --sign-key ../tests/keys/sign.key --cert-file sign.pem --persistent-file persistents.txt --rootfs-directory data - -sudo chown -R $(whoami) data diff --git a/tools/emulator/README.md b/tools/emulator/README.md deleted file mode 100644 index a536010de..000000000 --- a/tools/emulator/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# ota-client emulator - -This tool mimics ota-client behavior. -This tool can be used to develop software which requests ota-client. - -## Usage - -```bash -python main.py --config config.yml -``` diff --git a/tools/emulator/config.yml b/tools/emulator/config.yml deleted file mode 100644 index 8532f3185..000000000 --- a/tools/emulator/config.yml +++ /dev/null @@ -1,31 +0,0 @@ -ecus: -- main: true # main ecu or not. only one ecu should be main. - name: autoware # should be unique - status: INITIALIZED # INITIALIZED | SUCCESS | FAILURE | UPDATING | NO_CONNECTION - version: 123.456 # current version - time_to_update: 40 # in second - time_to_restart: 10 # in second - -- name: perception1 # should be unique - status: INITIALIZED # INITIALIZED | SUCCESS | FAILURE | UPDATING | NO_CONNECTION - version: abc.def # current version - time_to_update: 60 # in second - time_to_restart: 10 # in second - -- name: perception2 - status: INITIALIZED # INITIALIZED | SUCCESS | FAILURE | UPDATING | NO_CONNECTION - version: abc.def - time_to_update: 60 # in second - time_to_restart: 10 # in second - -- name: perception3 - status: INITIALIZED # INITIALIZED | SUCCESS | FAILURE | UPDATING | NO_CONNECTION - version: abc.def - time_to_update: 60 # in second - time_to_restart: 10 # in second - -#- name: perception4 -# status: INITIALIZED # INITIALIZED | SUCCESS | FAILURE | UPDATING | NO_CONNECTION -# version: abc.def -# time_to_update: 60 # in second -# time_to_restart: 10 # in second diff --git a/tools/emulator/ecu.py b/tools/emulator/ecu.py deleted file mode 100644 index afd8e2795..000000000 --- a/tools/emulator/ecu.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import time - -import log_setting -import otaclient_v2_pb2 as v2 -from configs import config as cfg - -logger = log_setting.get_logger( - __name__, cfg.LOG_LEVEL_TABLE.get(__name__, cfg.DEFAULT_LOG_LEVEL) -) - - -class Ecu: - TOTAL_REGULAR_FILES = 123456789 - TOTAL_REGULAR_FILE_SIZE = 987654321 - TIME_TO_UPDATE = 60 * 5 - TIME_TO_RESTART = 10 - - def __init__( - self, - is_main, - name, - status, - version, - time_to_update=TIME_TO_UPDATE, - time_to_restart=TIME_TO_RESTART, - ): - self._is_main = is_main - self._name = name - self._version = version - self._version_to_update = None - self._status = v2.Status() - self._status.status = v2.StatusOta.Value(status) - self._update_time = None - self._time_to_update = time_to_update - self._time_to_restart = time_to_restart - - def create(self): - return Ecu( - is_main=self._is_main, - name=self._name, - status=v2.StatusOta.Name(self._status.status), - version=self._version, - time_to_update=self._time_to_update, - time_to_restart=self._time_to_restart, - ) - - def change_to_success(self): - if self._status.status != v2.StatusOta.UPDATING: - logger.warning(f"current status: {v2.StatusOta.Name(self._status.status)}") - return - logger.info(f"change_to_success: {self._name=}") - self._status.status = v2.StatusOta.SUCCESS - self._version = self._version_to_update - - def update(self, response_ecu, version): - ecu = response_ecu - ecu.ecu_id = self._name - ecu.result = v2.FailureType.NO_FAILURE - - # update status - self._status = v2.Status() # reset - self._status.status = v2.StatusOta.UPDATING - self._version_to_update = version - self._update_time = time.time() - - def status(self, response_ecu): - ecu = response_ecu - ecu.ecu_id = self._name - ecu.result = v2.FailureType.NO_FAILURE - - try: - elapsed = time.time() - self._update_time - progress_rate = elapsed / self._time_to_update - should_restart = elapsed > (self._time_to_update + self._time_to_restart) - except TypeError: # when self._update_time is None - elapsed = 0 - progress_rate = 0 - should_restart = False - - # Main ecu waits for all sub ecu sucesss, while sub ecu transitions to - # success by itself. This code is intended to mimic that. - # The actual ecu updates, restarts and then transitions to success. - # In this code, after starting update and after time_to_update + - # time_to_restart elapsed, it transitions to success. - if not self._is_main and should_restart: - self.change_to_success() - else: - ecu.status.progress.CopyFrom(self._progress_rate_to_progress(progress_rate)) - - ecu.status.status = self._status.status - ecu.status.failure = v2.FailureType.NO_FAILURE - ecu.status.failure_reason = "" - ecu.status.version = self._version - - def _progress_rate_to_progress(self, rate): - progress = v2.StatusProgress() - if rate == 0: - progress.phase = v2.StatusProgressPhase.INITIAL - elif rate <= 0.01: - progress.phase = v2.StatusProgressPhase.METADATA - elif rate <= 0.02: - progress.phase = v2.StatusProgressPhase.DIRECTORY - elif rate <= 0.03: - progress.phase = v2.StatusProgressPhase.SYMLINK - elif rate <= 0.95: - progress.phase = v2.StatusProgressPhase.REGULAR - progress.total_regular_files = self.TOTAL_REGULAR_FILES - progress.regular_files_processed = int(self.TOTAL_REGULAR_FILES * rate) - - progress.files_processed_copy = int(progress.regular_files_processed * 0.4) - progress.files_processed_link = int(progress.regular_files_processed * 0.01) - progress.files_processed_download = ( - progress.regular_files_processed - - progress.files_processed_copy - - progress.files_processed_link - ) - size_processed = int(self.TOTAL_REGULAR_FILE_SIZE * rate) - progress.file_size_processed_copy = int(size_processed * 0.4) - progress.file_size_processed_link = int(size_processed * 0.01) - progress.file_size_processed_download = ( - size_processed - - progress.file_size_processed_copy - - progress.file_size_processed_link - ) - - progress.elapsed_time_copy.FromSeconds( - int(self._time_to_update * rate * 0.4) - ) - progress.elapsed_time_link.FromSeconds( - int(self._time_to_update * rate * 0.01) - ) - progress.elapsed_time_download.FromSeconds( - int(self._time_to_update * rate * 0.6) - ) - progress.errors_download = int(rate * 0.1) - progress.total_regular_file_size = self.TOTAL_REGULAR_FILE_SIZE - progress.total_elapsed_time.FromSeconds(int(self._time_to_update * rate)) - else: - progress.phase = v2.StatusProgressPhase.PERSISTENT - progress.total_regular_files = self.TOTAL_REGULAR_FILES - progress.regular_files_processed = self.TOTAL_REGULAR_FILES - - progress.files_processed_copy = int(progress.regular_files_processed * 0.4) - progress.files_processed_link = int(progress.regular_files_processed * 0.01) - progress.files_processed_download = ( - progress.regular_files_processed - - progress.files_processed_copy - - progress.files_processed_link - ) - size_processed = self.TOTAL_REGULAR_FILE_SIZE - progress.file_size_processed_copy = int(size_processed * 0.4) - progress.file_size_processed_link = int(size_processed * 0.01) - progress.file_size_processed_download = ( - size_processed - - progress.file_size_processed_copy - - progress.file_size_processed_link - ) - - progress.elapsed_time_copy.FromSeconds(int(self._time_to_update * 0.4)) - progress.elapsed_time_link.FromSeconds(int(self._time_to_update * 0.01)) - progress.elapsed_time_download.FromSeconds(int(self._time_to_update * 0.6)) - progress.errors_download = int(rate * 0.1) - progress.total_regular_file_size = self.TOTAL_REGULAR_FILE_SIZE - progress.total_elapsed_time.FromSeconds(self._time_to_update) - return progress diff --git a/tools/emulator/main.py b/tools/emulator/main.py deleted file mode 100644 index 8c859d54e..000000000 --- a/tools/emulator/main.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import time -from pathlib import Path - -import log_setting -import otaclient_v2_pb2 as v2 -import otaclient_v2_pb2_grpc as v2_grpc -import path_loader # noqa -import yaml -from configs import config as cfg -from configs import server_cfg -from ecu import Ecu -from ota_client_service import ( - OtaClientServiceV2, - service_start, - service_stop, - service_wait_for_termination, -) -from ota_client_stub import OtaClientStub - -logger = log_setting.get_logger( - __name__, cfg.LOG_LEVEL_TABLE.get(__name__, cfg.DEFAULT_LOG_LEVEL) -) - -DEFAULT_ECUS = [ - {"main": True, "id": "autoware", "status": "INITIALIZED", "version": "123.456"} -] - - -def main(config_file): - logger.info("started") - - server = None - - try: - config = yaml.safe_load(config_file.read_text()) - ecu_config = config["ecus"] - except Exception as e: - logger.warning(e) - logger.warning( - f"{config_file} couldn't be parsed. Default config is used instead." - ) - ecu_config = DEFAULT_ECUS - ecus = [] - logger.info(ecu_config) - for ecu in ecu_config: - e = Ecu( - is_main=ecu.get("main", False), - name=ecu.get("name", "autoware"), - status=ecu.get("status", "INITIALIZED"), - version=str(ecu.get("version", "")), - time_to_update=ecu.get("time_to_update"), - time_to_restart=ecu.get("time_to_restart"), - ) - ecus.append(e) - logger.info(ecus) - - def terminate(restart_time): - logger.info(f"{server=}") - service_stop(server) - logger.info(f"restarting. wait {restart_time}s.") - time.sleep(restart_time) - - while True: - ota_client_stub = OtaClientStub(ecus, terminate) - ota_client_service_v2 = OtaClientServiceV2(ota_client_stub) - - logger.info("starting grpc server.") - server = service_start( - f"localhost:{server_cfg.SERVER_PORT}", - [ - {"grpc": v2_grpc, "instance": ota_client_service_v2}, - ], - ) - - service_wait_for_termination(server) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--config", help="config.yml", default="config.yml") - args = parser.parse_args() - - logger.info(args) - - main(Path(args.config)) diff --git a/tools/emulator/ota_client_stub.py b/tools/emulator/ota_client_stub.py deleted file mode 100644 index c20d7c2ac..000000000 --- a/tools/emulator/ota_client_stub.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from pathlib import Path -from threading import Thread, Timer - -import log_setting -import otaclient_v2_pb2 as v2 -from configs import config as cfg -from ecu import Ecu - -logger = log_setting.get_logger( - __name__, cfg.LOG_LEVEL_TABLE.get(__name__, cfg.DEFAULT_LOG_LEVEL) -) - - -class OtaClientStub: - def __init__(self, ecus: list, terminate=None): - # check if all the names are unique - names = [ecu._name for ecu in ecus] - assert len(names) == len(set(names)) - # check if only one ecu is main - mains = [ecu for ecu in ecus if ecu._is_main] - assert len(mains) == 1 - - self._ecus = ecus - self._main_ecu = mains[0] - self._terminate = terminate - - async def update(self, request: v2.UpdateRequest) -> v2.UpdateResponse: - logger.info(f"{request=}") - response = v2.UpdateResponse() - - for ecu in self._ecus: - entry = OtaClientStub._find_request(request.ecu, ecu._name) - if entry: - logger.info(f"{ecu=}, {entry.version=}") - response_ecu = response.ecu.add() - ecu.update(response_ecu, entry.version) - - logger.info(f"{response=}") - return response - - def rollback(self, request): - logger.info(f"{request=}") - response = v2.RollbackResponse() - - return response - - async def status(self, request: v2.StatusRequest) -> v2.StatusResponse: - logger.info(f"{request=}") - response = v2.StatusResponse() - - for ecu in self._ecus: - response_ecu = response.ecu.add() - ecu.status(response_ecu) - response.available_ecu_ids.extend([ecu._name]) - - logger.debug(f"{response=}") - - if self._sub_ecus_success_and_main_ecu_phase_persistent(response.ecu): - self._main_ecu.change_to_success() - for index, ecu in enumerate(self._ecus): - self._ecus[index] = ecu.create() # create new ecu instances - self._terminate(self._main_ecu._time_to_restart) - - return response - - @staticmethod - def _find_request(update_request, ecu_id): - for request in update_request: - if request.ecu_id == ecu_id: - return request - return None - - def _update(self, ecu, response): - ecu.update(response) - - def _status(self, ecu, response): - ecu.status(response) - - def _sub_ecus_success_and_main_ecu_phase_persistent(self, response_ecu): - for ecu in response_ecu: - if ecu.ecu_id == self._main_ecu._name: - if ( - ecu.status.status != v2.StatusOta.UPDATING - or ecu.status.progress.phase != v2.StatusProgressPhase.PERSISTENT - ): - return False - else: - if ecu.status.status != v2.StatusOta.SUCCESS: - return False - return True diff --git a/tools/emulator/requirements.txt b/tools/emulator/requirements.txt deleted file mode 100644 index 3eadadb31..000000000 --- a/tools/emulator/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -grpcio==1.53.0 -protobuf==3.18.3 -PyYAML>=3.12 diff --git a/tools/offline_ota_image_builder/builder.py b/tools/offline_ota_image_builder/builder.py index bbf31e522..2a8b05b20 100644 --- a/tools/offline_ota_image_builder/builder.py +++ b/tools/offline_ota_image_builder/builder.py @@ -23,8 +23,8 @@ from pathlib import Path from typing import Mapping, Optional, Sequence -from otaclient.app import ota_metadata -from otaclient.app.common import subprocess_call +from ota_metadata.legacy import parser as ota_metadata_parser +from otaclient_common.common import subprocess_call from .configs import cfg from .manifest import ImageMetadata, Manifest @@ -70,9 +70,9 @@ def _process_ota_image(ota_image_dir: StrPath, *, data_dir: StrPath, meta_dir: S ota_image_dir = Path(ota_image_dir) # ------ process OTA image metadata ------ # - metadata_jwt_fpath = ota_image_dir / ota_metadata.OTAMetadata.METADATA_JWT + metadata_jwt_fpath = ota_image_dir / ota_metadata_parser.OTAMetadata.METADATA_JWT # NOTE: we don't need to do certificate verification here, so set certs_dir to empty - metadata_jwt = ota_metadata._MetadataJWTParser( + metadata_jwt = ota_metadata_parser._MetadataJWTParser( metadata_jwt_fpath.read_text(), certs_dir="" ).get_otametadata() @@ -82,7 +82,7 @@ def _process_ota_image(ota_image_dir: StrPath, *, data_dir: StrPath, meta_dir: S # ------ update data_dir with the contents of this OTA image ------ # with open(ota_image_dir / metadata_jwt.regular.file, "r") as f: for line in f: - reg_inf = ota_metadata.parse_regulars_from_txt(line) + reg_inf = ota_metadata_parser.parse_regulars_from_txt(line) ota_file_sha256 = reg_inf.sha256hash.hex() if reg_inf.compressed_alg == cfg.OTA_IMAGE_COMPRESSION_ALG: @@ -108,7 +108,7 @@ def _process_ota_image(ota_image_dir: StrPath, *, data_dir: StrPath, meta_dir: S # ------ update meta_dir with the OTA meta files in this image ------ # # copy OTA metafiles to image specific meta folder - for _metaf in ota_metadata.MetafilesV1: + for _metaf in ota_metadata_parser.MetafilesV1: shutil.move(str(ota_image_dir / _metaf.value), meta_dir) # copy certificate and metadata.jwt shutil.move(str(ota_image_dir / metadata_jwt.certificate.file), meta_dir) diff --git a/tools/status_monitor/ecu_status_box.py b/tools/status_monitor/ecu_status_box.py index 61cc38c1d..4b3a57cb5 100644 --- a/tools/status_monitor/ecu_status_box.py +++ b/tools/status_monitor/ecu_status_box.py @@ -13,13 +13,15 @@ # limitations under the License. +from __future__ import annotations + import curses import datetime import threading import time from typing import Sequence, Tuple -from otaclient.app.proto import wrapper as proto_wrapper +from otaclient_api.v2 import types as api_types from .configs import config from .utils import FormatValue, ScreenHandler @@ -47,7 +49,7 @@ def __init__(self, ecu_id: str, index: int) -> None: # contents for raw ecu status info sub window self.raw_ecu_status_contents = [] - self._last_status = proto_wrapper.StatusResponseEcuV2() + self._last_status = api_types.StatusResponseEcuV2() # prevent conflicts between status update and pad update self._lock = threading.Lock() self.last_updated = 0 @@ -60,9 +62,7 @@ def get_raw_info_contents(self) -> Tuple[Sequence[str], int]: """Getter for raw_ecu_status_contents.""" return self.raw_ecu_status_contents, self.last_updated - def update_ecu_status( - self, ecu_status: proto_wrapper.StatusResponseEcuV2, index: int - ): + def update_ecu_status(self, ecu_status: api_types.StatusResponseEcuV2, index: int): """Update internal contents storage with input . This method is called by tracker module to update the contents within @@ -80,7 +80,7 @@ def update_ecu_status( "-" * (self.DISPLAY_BOX_HCOLS - 2), ] - if ecu_status.ota_status is proto_wrapper.StatusOta.UPDATING: + if ecu_status.ota_status is api_types.StatusOta.UPDATING: update_status = ecu_status.update_status # TODO: render a progress bar according to ECU status V2's specification self.contents.extend( @@ -114,7 +114,7 @@ def update_ecu_status( "No detailed failure information.", ] - elif ecu_status.ota_status is proto_wrapper.StatusOta.FAILURE: + elif ecu_status.ota_status is api_types.StatusOta.FAILURE: self.contents.extend( [ f"ota_status: {ecu_status.ota_status.name}", diff --git a/tools/status_monitor/ecu_status_tracker.py b/tools/status_monitor/ecu_status_tracker.py index 5cebfc200..f4342f504 100644 --- a/tools/status_monitor/ecu_status_tracker.py +++ b/tools/status_monitor/ecu_status_tracker.py @@ -13,13 +13,15 @@ # limitations under the License. +from __future__ import annotations + import asyncio import threading from queue import Queue from typing import Dict, List, Optional -from otaclient.app.ota_client_call import ECUNoResponse, OtaClientCall -from otaclient.app.proto import wrapper as proto_wrapper +from otaclient_api.v2 import types as api_types +from otaclient_api.v2.api_caller import ECUNoResponse, OTAClientCall from .configs import config as cfg from .ecu_status_box import ECUStatusDisplayBox @@ -36,8 +38,11 @@ async def status_polling_thread( ): while not stop_event.is_set(): try: - resp = await OtaClientCall.status_call( - ecu_id, host, port, request=proto_wrapper.StatusRequest() + resp = await OTAClientCall.status_call( + ecu_id, + host, + port, + request=api_types.StatusRequest(), ) que.put_nowait(resp) except ECUNoResponse: @@ -84,7 +89,7 @@ def _polling_thread(): def _update_thread(): while not self._stop_event.is_set(): - _ecu_status: proto_wrapper.StatusResponse = self._que.get() + _ecu_status: api_types.StatusResponse = self._que.get() if _ecu_status is self._END_SENTINEL: return diff --git a/tools/test_utils/README.md b/tools/test_utils/README.md deleted file mode 100644 index 8384ef764..000000000 --- a/tools/test_utils/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Test utils for debugging otaclient - -This test utils set provides lib to directly query otaclient `update/status/rollback` API, and a tool to simulate dummy multi-ecu setups. - -## Usage guide for setting up test environemnt - -This test_utils can be used to setup a test environment consists of a real otaclient(either on VM or on actually ECU) as main ECU, -and setup many dummy subECUs that can receive update request and return the expected status report. - -### 1. Install the otaclient's dependencies - -`test_utils` depends on otaclient, so you need to install at least the dependencies of otaclient. -Please refer to [docs/INSTALLATION.md](docs/INSTALLATION.md). - -### 2. Update the `ecu_info.yaml` and `update_request.yaml` accordingly - -Update the `ecu_info.yaml` under the `test_utils` folder as your test environment design, -the example `ecu_info.yaml` consists of one mainECU `autoware`(expecting to be a real otaclient), -and 2 dummy subECUs which will be prepared at step 2. - -Update the `update_request.yaml` under the `test_utils` folder as your test environment setup. -This file contains the update request to be sent. - -### 3. Launch `setup_ecu.py` to setup dummy subECUs, and launch the real otaclient - -Setup subECUs: - -```python -# with venv, under the tools/ folder -python3 -m test_utils.setup_ecu subecus -``` - -And then launch the real otaclient, be sure that the otaclient is reachable to the machine -that running the test_utils. - -### 4. Send an update request to main ECU - -For example, we have `autoware` ECU as main ECU, then - -```python -# with venv, under the tools/ folder -python3 -m test_utils.api_caller update -t autoware -``` diff --git a/tools/test_utils/__init__.py b/tools/test_utils/__init__.py deleted file mode 100644 index bcfd866ad..000000000 --- a/tools/test_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tools/test_utils/_logutil.py b/tools/test_utils/_logutil.py deleted file mode 100644 index a26aabd64..000000000 --- a/tools/test_utils/_logutil.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -_log_format = ( - "[%(asctime)s][%(levelname)s]-%(filename)s:%(funcName)s:%(lineno)d,%(message)s" -) -logging.basicConfig(format=_log_format) - - -def get_logger(name: str, level: int = logging.DEBUG) -> logging.Logger: - logger = logging.getLogger(name) - logger.setLevel(level) - return logger diff --git a/tools/test_utils/_update_call.py b/tools/test_utils/_update_call.py deleted file mode 100644 index b3c45fcbb..000000000 --- a/tools/test_utils/_update_call.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import yaml - -from otaclient.app.ota_client_call import ECUNoResponse, OtaClientCall -from otaclient.app.proto import wrapper - -from . import _logutil - -logger = _logutil.get_logger(__name__) - - -def load_external_update_request(request_yaml_file: str) -> wrapper.UpdateRequest: - with open(request_yaml_file, "r") as f: - try: - request_yaml = yaml.safe_load(f) - assert isinstance(request_yaml, list), "expect update request to be a list" - except Exception as e: - logger.exception(f"invalid update request yaml: {e!r}") - raise - - logger.info(f"load external request: {request_yaml!r}") - request = wrapper.UpdateRequest() - for request_ecu in request_yaml: - request.ecu.append(wrapper.UpdateRequestEcu(**request_ecu)) - return request - - -async def call_update( - ecu_id: str, - ecu_ip: str, - ecu_port: int, - *, - request_file: str, -): - logger.debug(f"request update on ecu(@{ecu_id}) at {ecu_ip}:{ecu_port}") - update_request = load_external_update_request(request_file) - - try: - update_response = await OtaClientCall.update_call( - ecu_id, ecu_ip, ecu_port, request=update_request - ) - logger.info(f"{update_response.export_pb()=}") - except ECUNoResponse as e: - logger.exception(f"update request failed: {e!r}") diff --git a/tools/test_utils/api_caller.py b/tools/test_utils/api_caller.py deleted file mode 100644 index 529e9cd54..000000000 --- a/tools/test_utils/api_caller.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import asyncio -import sys -from pathlib import Path - -import yaml - -try: - import otaclient # noqa: F401 -except ImportError: - sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from . import _logutil, _update_call - -logger = _logutil.get_logger(__name__) - - -async def main(args: argparse.Namespace): - with open(args.ecu_info, "r") as f: - ecu_info = yaml.safe_load(f) - assert isinstance(ecu_info, dict) - - target_ecu_id = args.target - # by default, send request to main ECU - try: - ecu_id = ecu_info["ecu_id"] - ecu_ip = ecu_info["ip_addr"] - ecu_port = 50051 - except KeyError: - raise ValueError(f"invalid ecu_info: {ecu_info=}") - - if target_ecu_id != ecu_info.get("ecu_id"): - found = False - # search for target by ecu_id - for subecu in ecu_info.get("secondaries", []): - try: - if subecu["ecu_id"] == target_ecu_id: - ecu_id = subecu["ecu_id"] - ecu_ip = subecu["ip_addr"] - ecu_port = int(subecu.get("port", 50051)) - found = True - break - except KeyError: - continue - - if not found: - logger.critical(f"target ecu {target_ecu_id} is not found") - sys.exit(-1) - - logger.info(f"send request to target ecu: {ecu_id=}, {ecu_ip=}") - cmd = args.command - if cmd == "update": - await _update_call.call_update( - ecu_id, - ecu_ip, - ecu_port, - request_file=args.request, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="calling ECU's API", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-c", - "--ecu_info", - type=str, - default="test_utils/ecu_info.yaml", - help="ecu_info file to configure the caller", - ) - parser.add_argument("command", help="API to call, available API: update") - parser.add_argument( - "-t", - "--target", - default="autoware", - help="indicate the target for the API request", - ) - parser.add_argument( - "-r", - "--request", - default="test_utils/update_request.yaml", - help="(update) yaml file that contains the request to send", - ) - - args = parser.parse_args() - if args.command != "update": - parser.error(f"unknown API: {args.command} (available: update)") - if not Path(args.ecu_info).is_file(): - parser.error(f"ecu_info file {args.ecu_info} not found!") - if args.command == "update" and not Path(args.request).is_file(): - parser.error(f"update request file {args.request} not found!") - - asyncio.run(main(args)) diff --git a/tools/test_utils/ecu_info.yaml b/tools/test_utils/ecu_info.yaml deleted file mode 100644 index 2c8c7c8e4..000000000 --- a/tools/test_utils/ecu_info.yaml +++ /dev/null @@ -1,9 +0,0 @@ -# sample ecu_info.yaml, with 2 perception ECUs -format_version: 1 -ecu_id: "autoware" -ip_addr: "10.0.0.2" -secondaries: - - ecu_id: "p1" - ip_addr: "10.0.0.11" - - ecu_id: "p2" - ip_addr: "10.0.0.12" diff --git a/tools/test_utils/setup_ecu.py b/tools/test_utils/setup_ecu.py deleted file mode 100644 index 4190262a6..000000000 --- a/tools/test_utils/setup_ecu.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import asyncio -import sys -from pathlib import Path -from typing import List - -import grpc -import yaml - -try: - import otaclient # noqa: F401 -except ImportError: - sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from otaclient.app.ota_client_service import service_wait_for_termination -from otaclient.app.proto import v2, v2_grpc - -from . import _logutil - -logger = _logutil.get_logger(__name__) - -_DEFAULT_PORT = 50051 -_MODE = {"standalone", "mainecu", "subecus"} - - -class MiniOtaClientServiceV2(v2_grpc.OtaClientServiceServicer): - UPDATE_TIME_COST = 10 - REBOOT_INTERVAL = 5 - - def __init__(self, ecu_id: str): - self.ecu_id = ecu_id - self._lock = asyncio.Lock() - self._in_update = asyncio.Event() - self._rebooting = asyncio.Event() - - async def _on_update(self): - await asyncio.sleep(self.UPDATE_TIME_COST) - logger.debug(f"{self.ecu_id=} finished update, rebooting...") - self._rebooting.set() - await asyncio.sleep(self.REBOOT_INTERVAL) - self._rebooting.clear() - self._in_update.clear() - - async def Update(self, request: v2.UpdateRequest, context: grpc.ServicerContext): - peer = context.peer() - logger.debug(f"{self.ecu_id}: update request from {peer=}") - logger.debug(f"{request=}") - - # return if not listed as target - found = False - for ecu in request.ecu: - if ecu.ecu_id == self.ecu_id: - found = True - break - if not found: - logger.debug(f"{self.ecu_id}, Update: not listed as update target, abort") - return v2.UpdateResponse() - - results = v2.UpdateResponse() - if self._in_update.is_set(): - resp_ecu = v2.UpdateResponseEcu( - ecu_id=self.ecu_id, - result=v2.RECOVERABLE, - ) - results.ecu.append(resp_ecu) - else: - logger.debug("start update") - self._in_update.set() - asyncio.create_task(self._on_update()) - - return results - - async def Status(self, _, context: grpc.ServicerContext): - peer = context.peer() - logger.debug(f"{self.ecu_id}: status request from {peer=}") - if self._rebooting.is_set(): - return v2.StatusResponse() - - result_ecu = v2.StatusResponseEcu( - ecu_id=self.ecu_id, - result=v2.NO_FAILURE, - ) - - ecu_status = result_ecu.status - if self._in_update.is_set(): - ecu_status.status = v2.UPDATING - else: - ecu_status.status = v2.SUCCESS - - result = v2.StatusResponse() - result.ecu.append(result_ecu) - return result - - -async def launch_otaclient(ecu_id, ecu_ip, ecu_port): - server = grpc.aio.server() - service = MiniOtaClientServiceV2(ecu_id) - v2_grpc.add_OtaClientServiceServicer_to_server(service, server) - - server.add_insecure_port(f"{ecu_ip}:{ecu_port}") - await server.start() - await service_wait_for_termination(server) - - -async def mainecu_mode(ecu_info_file: str): - ecu_info = yaml.safe_load(Path(ecu_info_file).read_text()) - ecu_id = ecu_info["ecu_id"] - ecu_ip = ecu_info["ip_addr"] - ecu_port = int(ecu_info.get("port", _DEFAULT_PORT)) - - logger.info(f"start {ecu_id=} at {ecu_ip}:{ecu_port}") - await launch_otaclient(ecu_id, ecu_ip, ecu_port) - - -async def subecu_mode(ecu_info_file: str): - ecu_info = yaml.safe_load(Path(ecu_info_file).read_text()) - - # schedule the servers to the thread pool - tasks: List[asyncio.Task] = [] - for subecu in ecu_info["secondaries"]: - ecu_id = subecu["ecu_id"] - ecu_ip = subecu["ip_addr"] - ecu_port = int(subecu.get("port", _DEFAULT_PORT)) - logger.info(f"start {ecu_id=} at {ecu_ip}:{ecu_port}") - tasks.append(asyncio.create_task(launch_otaclient(ecu_id, ecu_ip, ecu_port))) - - await asyncio.gather(*tasks) - - -async def standalone_mode(args: argparse.Namespace): - await launch_otaclient("standalone", args.ip, args.port) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="calling main ECU's API", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-c", "--ecu_info", default="test_utils/ecu_info.yaml", help="ecu_info" - ) - parser.add_argument( - "mode", - default="standalone", - help=( - "running mode for mini_ota_client(standalone, subecus, mainecu)\n" - "\tstandalone: run a single mini_ota_client\n" - "\tmainecu: run a single mini_ota_client as mainecu according to ecu_info.yaml\n" - "\tsubecus: run subecu(s) according to ecu_info.yaml" - ), - ) - parser.add_argument( - "--ip", - default="127.0.0.1", - help="(standalone) listen at IP", - ) - parser.add_argument( - "--port", - default=_DEFAULT_PORT, - help="(standalone) use port PORT", - ) - - args = parser.parse_args() - - if args.mode not in _MODE: - parser.error(f"invalid mode {args.mode}, should be one of {_MODE}") - if args.mode != "standalone" and not Path(args.ecu_info).is_file(): - parser.error( - f"invalid ecu_info_file {args.ecu_info!r}. ecu_info.yaml is required for non-standalone mode" - ) - - if args.mode == "subecus": - logger.info("subecus mode") - coro = subecu_mode(args.ecu_info) - elif args.mode == "mainecu": - logger.info("mainecu mode") - coro = mainecu_mode(args.ecu_info) - else: - logger.info("standalone mode") - coro = standalone_mode(args) - - asyncio.run(coro) diff --git a/tools/test_utils/update_request.yaml b/tools/test_utils/update_request.yaml deleted file mode 100644 index 02e612a9e..000000000 --- a/tools/test_utils/update_request.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# sample update request -- ecu_id: "autoware" - version: "789.x" - url: "http://10.0.0.1:8080" - cookies: '{"test": "my-cookie"}' -- ecu_id: "p1" - version: "789.x" - url: "http://10.0.0.1:8080" - cookies: '{"test": "my-cookie"}' -- ecu_id: "p2" - version: "789.x" - url: "http://10.0.0.1:8080" - cookies: '{"test": "my-cookie"}'