From e561d585710e4c7b774a906491ec41c0224103e3 Mon Sep 17 00:00:00 2001 From: Bodong Yang <86948717+Bodong-Yang@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:10:22 +0900 Subject: [PATCH] feat: split and run grpc_server, ota_core and otaclient main as separated processes (#431) This PR introduces runtime re-architecture of otaclient, now at runtime each core otaclient components run as standalone processes, utilizing IPC to work together as a whole otaclient: 1. main process(by main module): the daemon process that first sets up the queue and shared_memory for each components to work together, and then brings up each component processes one by one. During otaclient running life-cycle, it will monitor each components status(active or exited) and handle signals(SIGTERM, SIGINT) accordingly. 2. ota_core process (by ota_core module): the process running the core implementation of OTA. It handles the OTA request and actually do the OTA, while reports its status via shared_memory all the time. 3. grpc_server process (by otaclient.grpc.api_v2 package): the prcoess running OTA Service API grpc server, it handles the OTA API requests, and translates it into otaclient internal form and dispatches to ota_core process via IPC. 4. ota_proxy process (by ota_proxy package): the process running ota_proxy server, will only be brought up when there is active OTA within the cluster. A simple otaclient internal IPC interface is implemented using simple types defined in otaclient._types with queue between API grpc server and ota_core. grpc_server dispatches OTA requests down to ota_core(with op_queue) and get the response to the request from ota_core(with ack_queue). A cros-process status report mechanism is implemented based on sharing hmac-protected(with one-time preshared key) pickled object with shared_memory. ota_core uses this mechanism to write its latest status into the shm, and grpc_server process read this shm to track the ota_core's internal status, and translate it into the OTA grpc API format. --- src/otaclient/_otaproxy_ctx.py | 140 ++++++++ src/otaclient/_status_monitor.py | 89 +++-- src/otaclient/_types.py | 38 ++- src/otaclient/{utils.py => _utils.py} | 57 +++- src/otaclient/errors.py | 2 +- src/otaclient/grpc/_otaproxy_ctx.py | 145 -------- src/otaclient/grpc/api_v2/ecu_status.py | 85 ++--- src/otaclient/grpc/api_v2/ecu_tracker.py | 66 +++- src/otaclient/grpc/api_v2/main.py | 87 +++++ src/otaclient/grpc/api_v2/servicer.py | 258 +++++++-------- src/otaclient/grpc/api_v2/types.py | 12 +- src/otaclient/main.py | 194 +++++++---- src/otaclient/ota_core.py | 310 ++++++++++++------ tests/conftest.py | 17 +- tests/test_otaclient/test_create_standby.py | 39 +-- .../test_grpc/test_api_v2/test_ecu_status.py | 51 ++- .../test_grpc/test_api_v2/test_servicer.py | 298 ----------------- tests/test_otaclient/test_main.py | 68 ---- tests/test_otaclient/test_ota_core.py | 70 ++-- tests/test_otaclient/test_status_monitor.py | 54 +-- tests/test_otaclient/test_utils.py | 5 +- 21 files changed, 1059 insertions(+), 1026 deletions(-) create mode 100644 src/otaclient/_otaproxy_ctx.py rename src/otaclient/{utils.py => _utils.py} (65%) delete mode 100644 src/otaclient/grpc/_otaproxy_ctx.py create mode 100644 src/otaclient/grpc/api_v2/main.py delete mode 100644 tests/test_otaclient/test_grpc/test_api_v2/test_servicer.py delete mode 100644 tests/test_otaclient/test_main.py diff --git a/src/otaclient/_otaproxy_ctx.py b/src/otaclient/_otaproxy_ctx.py new file mode 100644 index 000000000..e08a50638 --- /dev/null +++ b/src/otaclient/_otaproxy_ctx.py @@ -0,0 +1,140 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Control of the otaproxy server startup/shutdown. + +The API exposed by this module is meant to be controlled by otaproxy managing thread only. +""" + + +from __future__ import annotations + +import asyncio +import atexit +import logging +import multiprocessing as mp +import multiprocessing.context as mp_ctx +import shutil +import time +from functools import partial +from pathlib import Path + +from ota_proxy import config as local_otaproxy_cfg +from ota_proxy import run_otaproxy +from ota_proxy.config import config as otaproxy_cfg +from otaclient._types import MultipleECUStatusFlags +from otaclient.configs.cfg import cfg, proxy_info +from otaclient_common.common import ensure_otaproxy_start + +logger = logging.getLogger(__name__) + +_otaproxy_p: mp_ctx.SpawnProcess | None = None +_global_shutdown: bool = False + + +def shutdown_otaproxy_server() -> None: + global _otaproxy_p, _global_shutdown + _global_shutdown = True + if _otaproxy_p: + _otaproxy_p.terminate() + _otaproxy_p.join() + _otaproxy_p = None + + +OTAPROXY_CHECK_INTERVAL = 3 +OTAPROXY_MIN_STARTUP_TIME = 120 +"""Keep otaproxy running at least 60 seconds after startup.""" +OTA_CACHE_DIR_CHECK_INTERVAL = 60 + + +def otaproxy_process(*, init_cache: bool) -> None: + from otaclient._logging import configure_logging + + configure_logging() + logger.info("otaproxy process started") + + external_cache_mnt_point = None + if cfg.OTAPROXY_ENABLE_EXTERNAL_CACHE: + external_cache_mnt_point = cfg.EXTERNAL_CACHE_DEV_MOUNTPOINT + + host, port = ( + str(proxy_info.local_ota_proxy_listen_addr), + proxy_info.local_ota_proxy_listen_port, + ) + + upper_proxy = str(proxy_info.upper_ota_proxy or "") + logger.info(f"will launch otaproxy at http://{host}:{port}, with {upper_proxy=}") + if upper_proxy: + logger.info(f"wait for {upper_proxy=} online...") + ensure_otaproxy_start(str(upper_proxy)) + + asyncio.run( + run_otaproxy( + host=host, + port=port, + init_cache=init_cache, + cache_dir=local_otaproxy_cfg.BASE_DIR, + cache_db_f=local_otaproxy_cfg.DB_FILE, + upper_proxy=upper_proxy, + enable_cache=proxy_info.enable_local_ota_proxy_cache, + enable_https=proxy_info.gateway_otaproxy, + external_cache_mnt_point=external_cache_mnt_point, + ) + ) + + +def otaproxy_control_thread( + ecu_status_flags: MultipleECUStatusFlags, +) -> None: # pragma: no cover + atexit.register(shutdown_otaproxy_server) + + _mp_ctx = mp.get_context("spawn") + + ota_cache_dir = Path(otaproxy_cfg.BASE_DIR) + next_ota_cache_dir_checkpoint = 0 + + global _otaproxy_p + while not _global_shutdown: + time.sleep(OTAPROXY_CHECK_INTERVAL) + _now = time.time() + + _otaproxy_running = _otaproxy_p and _otaproxy_p.is_alive() + _otaproxy_should_run = ecu_status_flags.any_requires_network.is_set() + _all_success = ecu_status_flags.all_success.is_set() + + if not _otaproxy_should_run and not _otaproxy_running: + if ( + _now > next_ota_cache_dir_checkpoint + and _all_success + and ota_cache_dir.is_dir() + ): + logger.info( + "all tracked ECUs are in SUCCESS OTA status, cleanup ota cache dir ..." + ) + next_ota_cache_dir_checkpoint = _now + OTA_CACHE_DIR_CHECK_INTERVAL + shutil.rmtree(ota_cache_dir, ignore_errors=True) + + elif _otaproxy_should_run and not _otaproxy_running: + # NOTE: always try to re-use cache. If the cache dir is empty, otaproxy + # will still init the cache even init_cache is False. + _otaproxy_p = _mp_ctx.Process( + target=partial(otaproxy_process, init_cache=False), + name="otaproxy", + ) + _otaproxy_p.start() + next_ota_cache_dir_checkpoint = _now + OTAPROXY_MIN_STARTUP_TIME + time.sleep(OTAPROXY_MIN_STARTUP_TIME) # prevent pre-mature shutdown + + elif _otaproxy_p and _otaproxy_running and not _otaproxy_should_run: + logger.info("shutting down otaproxy as not needed now ...") + shutdown_otaproxy_server() diff --git a/src/otaclient/_status_monitor.py b/src/otaclient/_status_monitor.py index f6686b6cd..9db09cfe0 100644 --- a/src/otaclient/_status_monitor.py +++ b/src/otaclient/_status_monitor.py @@ -23,7 +23,7 @@ from dataclasses import asdict, dataclass from enum import Enum, auto from threading import Thread -from typing import Union, cast +from typing import Literal, Union, cast from otaclient._types import ( FailureType, @@ -34,17 +34,25 @@ UpdateProgress, UpdateTiming, ) +from otaclient._utils import SharedOTAClientStatusWriter +from otaclient_common.logging import BurstSuppressFilter logger = logging.getLogger(__name__) +burst_suppressed_logger = logging.getLogger(f"{__name__}.shm_push") +# NOTE: for request_error, only allow max 6 lines of logging per 30 seconds +burst_suppressed_logger.addFilter( + BurstSuppressFilter( + f"{__name__}.shm_push", + upper_logger_name=__name__, + burst_round_length=30, + burst_max=6, + ) +) -_otaclient_shutdown = False _status_report_queue: queue.Queue | None = None def _global_shutdown(): - global _otaclient_shutdown - _otaclient_shutdown = True - if _status_report_queue: _status_report_queue.put_nowait(TERMINATE_SENTINEL) @@ -120,7 +128,7 @@ class StatusReport: # def _on_session_finished( status_storage: OTAClientStatus, payload: OTAStatusChangeReport -): +) -> Literal[True]: status_storage.session_id = "" status_storage.update_phase = UpdatePhase.INITIALIZING status_storage.update_meta = UpdateMeta() @@ -137,10 +145,12 @@ def _on_session_finished( status_storage.failure_reason = "" status_storage.failure_traceback = "" + return True + def _on_new_ota_session( status_storage: OTAClientStatus, payload: OTAStatusChangeReport -): +) -> Literal[True]: status_storage.ota_status = payload.new_ota_status status_storage.update_phase = UpdatePhase.INITIALIZING status_storage.update_meta = UpdateMeta() @@ -149,6 +159,8 @@ def _on_new_ota_session( status_storage.failure_type = FailureType.NO_FAILURE status_storage.failure_reason = "" + return True + def _on_update_phase_changed( status_storage: OTAClientStatus, payload: OTAUpdatePhaseChangeReport @@ -157,7 +169,7 @@ def _on_update_phase_changed( logger.warning( "attempt to update update_timing when no OTA update session on-going" ) - return + return False phase, trigger_timestamp = payload.new_update_phase, payload.trigger_timestamp if phase == UpdatePhase.PROCESSING_POSTUPDATE: @@ -170,14 +182,17 @@ def _on_update_phase_changed( update_timing.update_apply_start_timestamp = trigger_timestamp status_storage.update_phase = phase + return True -def _on_update_progress(status_storage: OTAClientStatus, payload: UpdateProgressReport): +def _on_update_progress( + status_storage: OTAClientStatus, payload: UpdateProgressReport +) -> bool: if (update_progress := status_storage.update_progress) is None: logger.warning( "attempt to update update_progress when no OTA update session on-going" ) - return + return False op = payload.operation if ( @@ -195,6 +210,7 @@ def _on_update_progress(status_storage: OTAClientStatus, payload: UpdateProgress update_progress.downloading_errors += payload.errors elif op == UpdateProgressReport.Type.APPLY_REMOVE_DELTA: update_progress.removed_files_num += payload.processed_file_num + return True def _on_update_meta(status_storage: OTAClientStatus, payload: SetUpdateMetaReport): @@ -204,7 +220,7 @@ def _on_update_meta(status_storage: OTAClientStatus, payload: SetUpdateMetaRepor logger.warning( "attempt to update update_meta when no OTA update session on-going" ) - return + return False _input = asdict(payload) for k, v in _input.items(): @@ -213,31 +229,45 @@ def _on_update_meta(status_storage: OTAClientStatus, payload: SetUpdateMetaRepor continue if v: setattr(update_meta, k, v) + return True # # ------ status monitor implementation ------ # # +# A sentinel object to tell the thread stop TERMINATE_SENTINEL = cast(StatusReport, object()) +MIN_COLLECT_INTERVAL = 0.5 # seconds +SHM_PUSH_INTERVAL = 0.5 # seconds class OTAClientStatusCollector: + """NOTE: status_monitor will only be started once during whole otaclient lifecycle!""" def __init__( self, msg_queue: queue.Queue[StatusReport], + shm_status: SharedOTAClientStatusWriter, *, - min_collect_interval: int = 1, - min_push_interval: int = 1, + min_collect_interval: float = MIN_COLLECT_INTERVAL, + shm_push_interval: float = SHM_PUSH_INTERVAL, + max_traceback_size: int, ) -> None: + self.max_traceback_size = max_traceback_size self.min_collect_interval = min_collect_interval - self.min_push_interval = min_push_interval + self.shm_push_interval = shm_push_interval self._input_queue = msg_queue + global _status_report_queue + _status_report_queue = msg_queue + self._status = None + self._shm_status = shm_status + + atexit.register(shm_status.atexit) - def load_report(self, report: StatusReport): + def load_report(self, report: StatusReport) -> bool: if self._status is None: self._status = OTAClientStatus() status_storage = self._status @@ -246,37 +276,56 @@ def load_report(self, report: StatusReport): # ------ update otaclient meta ------ # if isinstance(payload, SetOTAClientMetaReport): status_storage.firmware_version = payload.firmware_version + return True # ------ on session start/end ------ # if isinstance(payload, OTAStatusChangeReport): + if (_traceback := payload.failure_traceback) and len( + _traceback + ) > self.max_traceback_size: + payload.failure_traceback = _traceback[-self.max_traceback_size :] + new_ota_status = payload.new_ota_status if new_ota_status in [OTAStatus.UPDATING, OTAStatus.ROLLBACKING]: status_storage.session_id = report.session_id return _on_new_ota_session(status_storage, payload) - status_storage.session_id = "" # clear session if we are not in an OTA return _on_session_finished(status_storage, payload) # ------ during OTA session ------ # report_session_id = report.session_id if report_session_id != status_storage.session_id: - logger.warning(f"drop reports from mismatched session: {report}") - return # drop invalid report + logger.warning( + f"drop reports from mismatched session (expect {status_storage.session_id=}): {report}" + ) + return False if isinstance(payload, OTAUpdatePhaseChangeReport): return _on_update_phase_changed(status_storage, payload) if isinstance(payload, UpdateProgressReport): return _on_update_progress(status_storage, payload) if isinstance(payload, SetUpdateMetaReport): return _on_update_meta(status_storage, payload) + return False def _status_collector_thread(self) -> None: """Main entry of status monitor working thread.""" - while not _otaclient_shutdown: + _next_shm_push = 0 + while True: + _now = time.time() try: report = self._input_queue.get_nowait() if report is TERMINATE_SENTINEL: break - self.load_report(report) + + # ------ push status on load_report ------ # + if self.load_report(report) and self._status and _now > _next_shm_push: + try: + self._shm_status.write_msg(self._status) + _next_shm_push = _now + self.shm_push_interval + except Exception as e: + burst_suppressed_logger.debug( + f"failed to push status to shm: {e!r}" + ) except queue.Empty: time.sleep(self.min_collect_interval) diff --git a/src/otaclient/_types.py b/src/otaclient/_types.py index 2f54a954b..92981290d 100644 --- a/src/otaclient/_types.py +++ b/src/otaclient/_types.py @@ -16,6 +16,7 @@ from __future__ import annotations +import multiprocessing.synchronize as mp_sync from dataclasses import dataclass from typing import ClassVar, Optional @@ -124,7 +125,39 @@ class OTAClientStatus: @dataclass -class UpdateRequestV2: +class MultipleECUStatusFlags: + any_in_update: mp_sync.Event + any_requires_network: mp_sync.Event + all_success: mp_sync.Event + + +# +# ------ OTA requests IPC ------ # +# + + +class IPCResEnum(StrEnum): + ACCEPT = "ACCEPT" + REJECT_BUSY = "REJECT_BUSY" + """The request has been rejected due to otaclient is busy.""" + REJECT_OTHER = "REJECT_OTHER" + """The request has been rejected for other reason.""" + + +@dataclass +class IPCResponse: + res: IPCResEnum + session_id: str + msg: str = "" + + +@dataclass +class IPCRequest: + session_id: str + + +@dataclass +class UpdateRequestV2(IPCRequest): """Compatible with OTA API version 2.""" version: str @@ -132,5 +165,6 @@ class UpdateRequestV2: cookies_json: str -class RollbackRequestV2: +@dataclass +class RollbackRequestV2(IPCRequest): """Compatbile with OTA API version 2.""" diff --git a/src/otaclient/utils.py b/src/otaclient/_utils.py similarity index 65% rename from src/otaclient/utils.py rename to src/otaclient/_utils.py index e98a3aabd..9fffe2c21 100644 --- a/src/otaclient/utils.py +++ b/src/otaclient/_utils.py @@ -22,26 +22,22 @@ import sys import time import traceback -from abc import abstractmethod from pathlib import Path -from typing import Callable, Protocol +from typing import Callable, Literal +from otaclient._types import OTAClientStatus from otaclient_common._io import read_str_from_file, write_str_to_file_atomic +from otaclient_common.shm_status import MPSharedStatusReader, MPSharedStatusWriter from otaclient_common.typing import StrOrPath logger = logging.getLogger(__name__) -class CheckableFlag(Protocol): - - @abstractmethod - def is_set(self) -> bool: ... - - def wait_and_log( - flag: CheckableFlag, + check_flag: Callable[[], bool], message: str = "", *, + check_for: Literal[True] | Literal[False] = True, check_interval: int = 2, log_interval: int = 30, log_func: Callable[[str], None] = logger.info, @@ -49,7 +45,7 @@ def wait_and_log( """Wait for until it is set while print a log every .""" log_round = 0 for seconds in itertools.count(step=check_interval): - if flag.is_set(): + if check_flag() == check_for: return _new_log_round = seconds // log_interval @@ -59,24 +55,24 @@ def wait_and_log( time.sleep(check_interval) -def check_other_otaclient(pid_fpath: StrOrPath) -> None: +def check_other_otaclient(pid_fpath: StrOrPath) -> None: # pragma: no cover """Check if there is another otaclient instance running, and then - create a pid lock file for this otaclient instance.""" - pid_fpath = Path(pid_fpath) + create a pid lock file for this otaclient instance. + NOTE that otaclient should not run inside a PID namespace. + """ + pid_fpath = Path(pid_fpath) if pid := read_str_from_file(pid_fpath, _default=""): # running process will have a folder under /proc if Path(f"/proc/{pid}").is_dir(): logger.error(f"another instance of ota-client({pid=}) is running, abort") sys.exit() - logger.warning(f"dangling otaclient lock file({pid=}) detected, cleanup") - Path(pid_fpath).unlink(missing_ok=True) - + pid_fpath.unlink(missing_ok=True) write_str_to_file_atomic(pid_fpath, f"{os.getpid()}") -def create_otaclient_rundir(run_dir: StrOrPath = "/run/otaclient"): +def create_otaclient_rundir(run_dir: StrOrPath = "/run/otaclient") -> None: """Create the otaclient runtime working dir. TODO: make a helper class for managing otaclient runtime dir. @@ -85,6 +81,31 @@ def create_otaclient_rundir(run_dir: StrOrPath = "/run/otaclient"): run_dir.mkdir(exist_ok=True, parents=True) -def get_traceback(exc: Exception, *, splitter: str = "\n") -> str: +def get_traceback(exc: Exception, *, splitter: str = "\n") -> str: # pragma: no cover """Format the traceback as string.""" return splitter.join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + +class SharedOTAClientStatusWriter(MPSharedStatusWriter[OTAClientStatus]): + """Util for writing OTAClientStatus to shm.""" + + +class SharedOTAClientStatusReader(MPSharedStatusReader[OTAClientStatus]): + """Util for reading OTAClientStatus from shm.""" + + +SESSION_RANDOM_LEN = 4 # bytes, the corresponding hex string will be 8 chars + + +def gen_session_id( + update_version: str, *, random_bytes_num: int = SESSION_RANDOM_LEN +) -> str: # pragma: no cover + """Generate a unique session_id for the new OTA session. + + token schema: + --<4bytes_hex> + """ + _time_factor = str(int(time.time())) + _random_factor = os.urandom(random_bytes_num).hex() + + return f"{update_version}-{_time_factor}-{_random_factor}" diff --git a/src/otaclient/errors.py b/src/otaclient/errors.py index 2aa7cc791..52b5ce23e 100644 --- a/src/otaclient/errors.py +++ b/src/otaclient/errors.py @@ -20,7 +20,7 @@ from typing import ClassVar from otaclient._types import FailureType -from otaclient.utils import get_traceback +from otaclient._utils import get_traceback @unique diff --git a/src/otaclient/grpc/_otaproxy_ctx.py b/src/otaclient/grpc/_otaproxy_ctx.py deleted file mode 100644 index fab3136bf..000000000 --- a/src/otaclient/grpc/_otaproxy_ctx.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Control of the launch/shutdown of otaproxy according to sub ECUs' status.""" - - -from __future__ import annotations - -import asyncio -import logging -import multiprocessing as mp -import multiprocessing.context as mp_ctx -import shutil -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from pathlib import Path -from typing import Optional - -from ota_proxy import config as local_otaproxy_cfg -from ota_proxy import run_otaproxy -from otaclient.configs.cfg import cfg, proxy_info -from otaclient_common.common import ensure_otaproxy_start - -logger = logging.getLogger(__name__) - - -def otaproxy_process(init_cache: bool) -> None: - from otaclient._logging import configure_logging - - configure_logging() - logger.info("otaproxy process started") - - external_cache_mnt_point = None - if cfg.OTAPROXY_ENABLE_EXTERNAL_CACHE: - external_cache_mnt_point = cfg.EXTERNAL_CACHE_DEV_MOUNTPOINT - - host, port = ( - str(proxy_info.local_ota_proxy_listen_addr), - proxy_info.local_ota_proxy_listen_port, - ) - - upper_proxy = str(proxy_info.upper_ota_proxy or "") - logger.info(f"will launch otaproxy at http://{host}:{port}, with {upper_proxy=}") - if upper_proxy: - logger.info(f"wait for {upper_proxy=} online...") - ensure_otaproxy_start(str(upper_proxy)) - - asyncio.run( - run_otaproxy( - host=host, - port=port, - init_cache=init_cache, - cache_dir=local_otaproxy_cfg.BASE_DIR, - cache_db_f=local_otaproxy_cfg.DB_FILE, - upper_proxy=upper_proxy, - enable_cache=proxy_info.enable_local_ota_proxy_cache, - enable_https=proxy_info.gateway_otaproxy, - external_cache_mnt_point=external_cache_mnt_point, - ) - ) - - -class OTAProxyLauncher: - """Launcher of start/stop otaproxy in subprocess.""" - - def __init__(self, *, executor: ThreadPoolExecutor) -> None: - self.enabled = proxy_info.enable_local_ota_proxy - self._lock = asyncio.Lock() - # process start/shutdown will be dispatched to thread pool - self._run_in_executor = partial( - asyncio.get_event_loop().run_in_executor, executor - ) - self._otaproxy_subprocess: mp_ctx.SpawnProcess | None = None - - @property - def is_running(self) -> bool: - return ( - self.enabled - and self._otaproxy_subprocess is not None - and self._otaproxy_subprocess.is_alive() - ) - - # API - - def cleanup_cache_dir(self) -> None: - """ - NOTE: this method should only be called when all ECUs in the cluster - are in SUCCESS ota_status(overall_ecu_status.all_success==True). - """ - if (cache_dir := Path(local_otaproxy_cfg.BASE_DIR)).is_dir(): - logger.info("cleanup ota_cache on success") - shutil.rmtree(cache_dir, ignore_errors=True) - - async def start(self, *, init_cache: bool) -> Optional[int]: - """Start the otaproxy in a subprocess.""" - if not self.enabled or self._lock.locked() or self.is_running: - return - - _spawn_ctx = mp.get_context("spawn") - - async with self._lock: - otaproxy_subprocess = _spawn_ctx.Process( - target=otaproxy_process, - args=(init_cache,), - daemon=True, - name="otaproxy", - ) - self._otaproxy_subprocess = otaproxy_subprocess - await self._run_in_executor(otaproxy_subprocess.start) - logger.info( - f"otaproxy({otaproxy_subprocess.pid=}) started at " - f"{proxy_info.local_ota_proxy_listen_addr}:{proxy_info.local_ota_proxy_listen_port}" - ) - return otaproxy_subprocess.pid - - async def stop(self) -> None: - """Stop the otaproxy subprocess. - - NOTE: This method only shutdown the otaproxy process, it will not cleanup the - cache dir. cache dir cleanup is handled by other mechanism. - Check cleanup_cache_dir API for more details. - """ - if not self.enabled or self._lock.locked() or not self.is_running: - return - - def _shutdown() -> None: - if self._otaproxy_subprocess and self._otaproxy_subprocess.is_alive(): - logger.info("shuting down otaproxy server process...") - self._otaproxy_subprocess.terminate() - self._otaproxy_subprocess.join() - self._otaproxy_subprocess = None - - async with self._lock: - await self._run_in_executor(_shutdown) - logger.info("otaproxy closed") diff --git a/src/otaclient/grpc/api_v2/ecu_status.py b/src/otaclient/grpc/api_v2/ecu_status.py index 5809e6e1b..66b8b68b3 100644 --- a/src/otaclient/grpc/api_v2/ecu_status.py +++ b/src/otaclient/grpc/api_v2/ecu_status.py @@ -41,19 +41,19 @@ import asyncio import logging +import math import time from itertools import chain -from typing import Dict, Iterable, Optional, Set, TypeVar +from typing import Dict, Iterable, Optional -from otaclient._types import OTAClientStatus +from otaclient._types import MultipleECUStatusFlags, OTAClientStatus from otaclient.configs.cfg import cfg, ecu_info from otaclient.grpc.api_v2.types import convert_to_apiv2_status from otaclient_api.v2 import types as api_types +from otaclient_common.typing import T logger = logging.getLogger(__name__) -T = TypeVar("T") - # NOTE(20230522): # ECU will be treated as disconnected if we cannot get in touch with it # longer than * . @@ -83,9 +83,15 @@ def discard(self, value: T): class ECUStatusStorage: - def __init__(self) -> None: + def __init__( + self, + *, + ecu_status_flags: MultipleECUStatusFlags, + ) -> None: self.my_ecu_id = ecu_info.ecu_id self._writer_lock = asyncio.Lock() + + self.ecu_status_flags = ecu_status_flags # ECU status storage self.storage_last_updated_timestamp = 0 @@ -113,24 +119,21 @@ def __init__(self) -> None: ecu_info.get_available_ecu_ids() ) - self._all_ecus_status_v2: Dict[str, api_types.StatusResponseEcuV2] = {} - self._all_ecus_status_v1: Dict[str, api_types.StatusResponseEcu] = {} - self._all_ecus_last_contact_timestamp: Dict[str, int] = {} + self._all_ecus_status_v2: dict[str, api_types.StatusResponseEcuV2] = {} + self._all_ecus_status_v1: dict[str, api_types.StatusResponseEcu] = {} + self._all_ecus_last_contact_timestamp: dict[str, int] = {} # overall ECU status report self._properties_update_lock = asyncio.Lock() self.properties_last_update_timestamp = 0 - self.active_ota_update_present = asyncio.Event() self.lost_ecus_id = set() self.failed_ecus_id = set() self.in_update_ecus_id = set() self.in_update_child_ecus_id = set() - self.any_requires_network = False self.success_ecus_id = set() - self.all_success = False # property update task # NOTE: _debug_properties_update_shutdown_event is for test only, @@ -157,6 +160,7 @@ async def _generate_overall_status_report(self): NOTE: as special case, lost_ecus set is calculated against all reachable ECUs. """ self.properties_last_update_timestamp = cur_timestamp = int(time.time()) + ecu_status_flags = self.ecu_status_flags # check unreachable ECUs # NOTE(20230801): this property is calculated against all reachable ECUs, @@ -192,9 +196,9 @@ async def _generate_overall_status_report(self): f"{_new_in_update_ecu}, current updating ECUs: {in_update_ecus_id}" ) if in_update_ecus_id: - self.active_ota_update_present.set() + ecu_status_flags.any_in_update.set() else: - self.active_ota_update_present.clear() + ecu_status_flags.any_in_update.clear() # check if there is any failed child/self ECU in tracked active ECUs set _old_failed_ecus_id = self.failed_ecus_id @@ -213,7 +217,7 @@ async def _generate_overall_status_report(self): ) # check if any ECUs in the tracked tracked active ECUs set require network - self.any_requires_network = any( + if any( ( status.requires_network for status in chain( @@ -222,10 +226,15 @@ async def _generate_overall_status_report(self): if status.ecu_id in self._tracked_active_ecus and status.ecu_id not in lost_ecus ) - ) + ): + ecu_status_flags.any_requires_network.set() + else: + ecu_status_flags.any_requires_network.clear() # check if all tracked active_ota_ecus are in SUCCESS ota_status - _old_all_success, _old_success_ecus_id = self.all_success, self.success_ecus_id + _old_all_success = ecu_status_flags.all_success.is_set() + _old_success_ecus_id = self.success_ecus_id + self.success_ecus_id = { status.ecu_id for status in chain( @@ -236,18 +245,16 @@ async def _generate_overall_status_report(self): and status.ecu_id not in lost_ecus } # NOTE: all_success doesn't count the lost ECUs - self.all_success = len(self.success_ecus_id) == len(self._tracked_active_ecus) + if len(self.success_ecus_id) == len(self._tracked_active_ecus): + ecu_status_flags.all_success.set() + else: + ecu_status_flags.all_success.clear() + if _new_success_ecu := self.success_ecus_id.difference(_old_success_ecus_id): logger.info(f"new succeeded ECU(s) detected: {_new_success_ecu}") - if not _old_all_success and self.all_success: + if ecu_status_flags.all_success.is_set() and not _old_all_success: logger.info("all ECUs in the cluster are in SUCCESS ota_status") - logger.debug( - "overall ECU status reporrt updated:" - f"{self.lost_ecus_id=}, {self.in_update_ecus_id=},{self.any_requires_network=}," - f"{self.failed_ecus_id=}, {self.success_ecus_id=}, {self.all_success=}" - ) - async def _loop_updating_properties(self): """ECU status storage's self generating overall ECU status report task. @@ -311,7 +318,7 @@ async def update_from_local_ecu(self, local_status: OTAClientStatus): self._all_ecus_status_v2[ecu_id] = convert_to_apiv2_status(local_status) self._all_ecus_last_contact_timestamp[ecu_id] = cur_timestamp - async def on_ecus_accept_update_request(self, ecus_accept_update: Set[str]): + async def on_ecus_accept_update_request(self, ecus_accept_update: set[str]): """Update overall ECU status report directly on ECU(s) accept OTA update request. for the ECUs that accepts OTA update request, we: @@ -325,6 +332,7 @@ async def on_ecus_accept_update_request(self, ecus_accept_update: Set[str]): their ota_status to UPDATING on-time due to status polling interval mismatch), the above set value will be kept for seconds. """ + ecu_status_flags = self.ecu_status_flags async with self._properties_update_lock: self._tracked_active_ecus = _OrderedSet(ecus_accept_update) @@ -334,12 +342,11 @@ async def on_ecus_accept_update_request(self, ecus_accept_update: Set[str]): self.in_update_ecus_id.update(ecus_accept_update) self.in_update_child_ecus_id = self.in_update_ecus_id - {self.my_ecu_id} - - self.any_requires_network = True - self.all_success = False self.success_ecus_id -= ecus_accept_update - self.active_ota_update_present.set() + ecu_status_flags.all_success.clear() + ecu_status_flags.any_requires_network.set() + ecu_status_flags.any_in_update.set() def get_polling_interval(self) -> int: """Return if there is active OTA update, @@ -348,9 +355,10 @@ def get_polling_interval(self) -> int: NOTE: use get_polling_waiter if want to wait, only call this method if one only wants to get the polling interval value. """ + ecu_status_flags = self.ecu_status_flags return ( ACTIVE_POLLING_INTERVAL - if self.active_ota_update_present.is_set() + if ecu_status_flags.any_in_update.is_set() else IDLE_POLLING_INTERVAL ) @@ -364,19 +372,20 @@ def get_polling_waiter(self): or self.active_ota_update_present is set, return when one of the condition is met. """ + # waiter closure will slice the waiting time by <_inner_wait_interval>, + # add wait each slice one by one while checking the ecu_status_flags. + _inner_wait_interval = 1 # second async def _waiter(): - if self.active_ota_update_present.is_set(): + ecu_status_flags = self.ecu_status_flags + if ecu_status_flags.any_in_update.is_set(): await asyncio.sleep(ACTIVE_POLLING_INTERVAL) return - try: - await asyncio.wait_for( - self.active_ota_update_present.wait(), - timeout=IDLE_POLLING_INTERVAL, - ) - except asyncio.TimeoutError: - return + for _ in range(math.ceil(IDLE_POLLING_INTERVAL / _inner_wait_interval)): + if ecu_status_flags.any_in_update.is_set(): + return + await asyncio.sleep(_inner_wait_interval) return _waiter diff --git a/src/otaclient/grpc/api_v2/ecu_tracker.py b/src/otaclient/grpc/api_v2/ecu_tracker.py index 7a3a8a3c8..ba454f138 100644 --- a/src/otaclient/grpc/api_v2/ecu_tracker.py +++ b/src/otaclient/grpc/api_v2/ecu_tracker.py @@ -17,16 +17,34 @@ from __future__ import annotations import asyncio +import atexit import logging +from collections import defaultdict -from otaclient._status_monitor import OTAClientStatusCollector +from otaclient._utils import SharedOTAClientStatusReader from otaclient.configs import ECUContact from otaclient.configs.cfg import cfg, ecu_info from otaclient.grpc.api_v2.ecu_status import ECUStatusStorage from otaclient_api.v2 import types as api_types from otaclient_api.v2.api_caller import ECUNoResponse, OTAClientCall +from otaclient_common.logging import BurstSuppressFilter logger = logging.getLogger(__name__) +burst_suppressed_logger = logging.getLogger(f"{__name__}.local_ecu_check") +# NOTE: for request_error, only allow max 6 lines of logging per 30 seconds +burst_suppressed_logger.addFilter( + BurstSuppressFilter( + f"{__name__}.local_ecu_check", + upper_logger_name=__name__, + burst_round_length=30, + burst_max=6, + ) +) + +# actively polling ECUs status until we get the first valid response +# when otaclient is just starting. +_ACTIVE_POLL_SUB_ON_STARTUP = 1 +_ACTIVE_POLL_LOCAL_ON_STARTUP = 0.1 class ECUTracker: @@ -34,22 +52,20 @@ class ECUTracker: def __init__( self, ecu_status_storage: ECUStatusStorage, - *, - local_status_collector: OTAClientStatusCollector, + /, + local_ecu_status_reader: SharedOTAClientStatusReader, ) -> None: - self._local_status_collector = local_status_collector + self._local_ecu_status_reader = local_ecu_status_reader self._ecu_status_storage = ecu_status_storage self._polling_waiter = self._ecu_status_storage.get_polling_waiter() + self._startup_matrix: defaultdict[str, bool] = defaultdict(lambda: True) - # launch ECU trackers for all defined ECUs - # NOTE: _debug_ecu_status_polling_shutdown_event is for test only, - # allow us to stop background task without changing codes. - # In normal running this event will never be set. - self._debug_ecu_status_polling_shutdown_event = asyncio.Event() + atexit.register(local_ecu_status_reader.atexit) async def _polling_direct_subecu_status(self, ecu_contact: ECUContact): """Task entry for loop polling one subECU's status.""" - while not self._debug_ecu_status_polling_shutdown_event.is_set(): + this_ecu_id = ecu_contact.ecu_id + while True: try: _ecu_resp = await OTAClientCall.status_call( ecu_contact.ecu_id, @@ -58,20 +74,40 @@ async def _polling_direct_subecu_status(self, ecu_contact: ECUContact): timeout=cfg.QUERYING_SUBECU_STATUS_TIMEOUT, request=api_types.StatusRequest(), ) + if self._startup_matrix[this_ecu_id] and ( + _ecu_resp.find_ecu_v2(this_ecu_id) + or _ecu_resp.find_ecu(this_ecu_id) + ): + self._startup_matrix[this_ecu_id] = False await self._ecu_status_storage.update_from_child_ecu(_ecu_resp) except ECUNoResponse as e: logger.debug( f"ecu@{ecu_contact} doesn't respond to status request: {e!r}" ) - await self._polling_waiter() + + if self._startup_matrix[this_ecu_id]: + await asyncio.sleep(_ACTIVE_POLL_SUB_ON_STARTUP) + else: + await self._polling_waiter() async def _polling_local_ecu_status(self): """Task entry for loop polling local ECU status.""" - while not self._debug_ecu_status_polling_shutdown_event.is_set(): - status_report = self._local_status_collector.otaclient_status - if status_report: + my_ecu_id = ecu_info.ecu_id + while True: + try: + status_report = self._local_ecu_status_reader.sync_msg() + if status_report: + self._startup_matrix[my_ecu_id] = False await self._ecu_status_storage.update_from_local_ecu(status_report) - await self._polling_waiter() + except Exception as e: + burst_suppressed_logger.debug( + f"failed to query local ECU's status: {e!r}" + ) + + if self._startup_matrix[my_ecu_id]: + await asyncio.sleep(_ACTIVE_POLL_LOCAL_ON_STARTUP) + else: + await self._polling_waiter() def start(self) -> None: asyncio.create_task(self._polling_local_ecu_status()) diff --git a/src/otaclient/grpc/api_v2/main.py b/src/otaclient/grpc/api_v2/main.py new file mode 100644 index 000000000..38ca6481d --- /dev/null +++ b/src/otaclient/grpc/api_v2/main.py @@ -0,0 +1,87 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Main entry for OTA API v2 grpc server.""" + + +from __future__ import annotations + +import asyncio +import atexit +import logging +from concurrent.futures import ThreadPoolExecutor +from multiprocessing.queues import Queue as mp_Queue +from typing import Callable, NoReturn + +from otaclient._types import IPCRequest, IPCResponse, MultipleECUStatusFlags +from otaclient._utils import SharedOTAClientStatusReader + +logger = logging.getLogger(__name__) + + +def grpc_server_process( + *, + shm_reader_factory: Callable[[], SharedOTAClientStatusReader], + op_queue: mp_Queue[IPCRequest], + resp_queue: mp_Queue[IPCResponse], + ecu_status_flags: MultipleECUStatusFlags, +) -> NoReturn: # type: ignore + from otaclient._logging import configure_logging + + configure_logging() + logger.info("otaclient OTA API grpc server started") + + shm_reader = shm_reader_factory() + atexit.register(shm_reader.atexit) + + async def _grpc_server_launcher(): + import grpc.aio + + from otaclient.configs.cfg import cfg, ecu_info + from otaclient.grpc.api_v2.ecu_status import ECUStatusStorage + from otaclient.grpc.api_v2.ecu_tracker import ECUTracker + from otaclient.grpc.api_v2.servicer import OTAClientAPIServicer + from otaclient_api.v2 import otaclient_v2_pb2_grpc as v2_grpc + from otaclient_api.v2.api_stub import OtaClientServiceV2 + + ecu_status_storage = ECUStatusStorage(ecu_status_flags=ecu_status_flags) + ecu_tracker = ECUTracker(ecu_status_storage, shm_reader) + ecu_tracker.start() + + thread_pool = ThreadPoolExecutor( + thread_name_prefix="ota_api_server", + ) + api_servicer = OTAClientAPIServicer( + ecu_status_storage=ecu_status_storage, + op_queue=op_queue, + resp_queue=resp_queue, + executor=thread_pool, + ) + ota_client_service_v2 = OtaClientServiceV2(api_servicer) + + server = grpc.aio.server(migration_thread_pool=thread_pool) + v2_grpc.add_OtaClientServiceServicer_to_server( + server=server, servicer=ota_client_service_v2 + ) + _address_info = f"{ecu_info.ip_addr}:{cfg.OTA_API_SERVER_PORT}" + server.add_insecure_port(_address_info) + + logger.info(f"launch grpc API server at {_address_info}") + await server.start() + try: + await server.wait_for_termination() + finally: + await server.stop(1) + thread_pool.shutdown(wait=True) + + asyncio.run(_grpc_server_launcher()) diff --git a/src/otaclient/grpc/api_v2/servicer.py b/src/otaclient/grpc/api_v2/servicer.py index 65c8f9c58..a44ccb1a5 100644 --- a/src/otaclient/grpc/api_v2/servicer.py +++ b/src/otaclient/grpc/api_v2/servicer.py @@ -18,22 +18,27 @@ import asyncio import logging -import time +import multiprocessing.queues as mp_queue from concurrent.futures import ThreadPoolExecutor -from functools import partial -from typing import Dict +from otaclient._types import ( + IPCRequest, + IPCResEnum, + IPCResponse, + RollbackRequestV2, + UpdateRequestV2, +) +from otaclient._utils import gen_session_id from otaclient.configs import ECUContact -from otaclient.configs.cfg import cfg, ecu_info, proxy_info -from otaclient.grpc._otaproxy_ctx import OTAProxyLauncher +from otaclient.configs.cfg import cfg, ecu_info from otaclient.grpc.api_v2.ecu_status import ECUStatusStorage -from otaclient.grpc.api_v2.types import convert_from_apiv2_update_request -from otaclient.ota_core import OTAClient, OTAClientControlFlags from otaclient_api.v2 import types as api_types from otaclient_api.v2.api_caller import ECUNoResponse, OTAClientCall logger = logging.getLogger(__name__) +WAIT_FOR_LOCAL_ECU_ACK_TIMEOUT = 6 # seconds + class OTAClientAPIServicer: """Handlers for otaclient service API. @@ -41,103 +46,63 @@ class OTAClientAPIServicer: This class also handles otaproxy lifecyle and dependence managing. """ - OTAPROXY_SHUTDOWN_DELAY = cfg.OTAPROXY_MINIMUM_SHUTDOWN_INTERVAL - def __init__( self, - otaclient_inst: OTAClient, - ecu_status_storage: ECUStatusStorage, *, - control_flag: OTAClientControlFlags, + ecu_status_storage: ECUStatusStorage, + op_queue: mp_queue.Queue[IPCRequest], + resp_queue: mp_queue.Queue[IPCResponse], executor: ThreadPoolExecutor, ): - self._executor = executor - self._run_in_executor = partial( - asyncio.get_running_loop().run_in_executor, executor - ) - self.sub_ecus = ecu_info.secondaries self.listen_addr = ecu_info.ip_addr self.listen_port = cfg.OTA_API_SERVER_PORT self.my_ecu_id = ecu_info.ecu_id + self._executor = executor - self._otaclient_control_flags = control_flag - self._otaclient_inst = otaclient_inst + self._op_queue = op_queue + self._resp_queue = resp_queue self._ecu_status_storage = ecu_status_storage self._polling_waiter = self._ecu_status_storage.get_polling_waiter() - # otaproxy lifecycle and dependency managing - # NOTE: _debug_status_checking_shutdown_event is for test only, - # allow us to stop background task without changing codes. - # In normal running this event will never be set. - self._debug_status_checking_shutdown_event = asyncio.Event() - if proxy_info.enable_local_ota_proxy: - self._otaproxy_launcher = OTAProxyLauncher(executor=executor) - asyncio.create_task(self._otaproxy_lifecycle_managing()) - asyncio.create_task(self._otaclient_control_flags_managing()) - else: - # if otaproxy is not enabled, no dependency relationship will be formed, - # always allow local otaclient to reboot - self._otaclient_control_flags.set_can_reboot_flag() - - # internal - - async def _otaproxy_lifecycle_managing(self): - """Task entry for managing otaproxy's launching/shutdown. + # API servicer - NOTE: cache_dir cleanup is handled here, when all ECUs are in SUCCESS ota_status, - cache_dir will be removed. - """ - otaproxy_last_launched_timestamp = 0 - while not self._debug_status_checking_shutdown_event.is_set(): - cur_timestamp = int(time.time()) - any_requires_network = self._ecu_status_storage.any_requires_network - if self._otaproxy_launcher.is_running: - # NOTE: do not shutdown otaproxy too quick after it just starts! - # If otaproxy just starts less than seconds, - # skip the shutdown this time. - if ( - not any_requires_network - and cur_timestamp - > otaproxy_last_launched_timestamp + self.OTAPROXY_SHUTDOWN_DELAY - ): - await self._otaproxy_launcher.stop() - otaproxy_last_launched_timestamp = 0 - else: # otaproxy is not running - if any_requires_network: - await self._otaproxy_launcher.start(init_cache=False) - otaproxy_last_launched_timestamp = cur_timestamp - # when otaproxy is not running and any_requires_network is False, - # cleanup the cache dir when all ECUs are in SUCCESS ota_status - elif self._ecu_status_storage.all_success: - self._otaproxy_launcher.cleanup_cache_dir() - await self._polling_waiter() + def _local_update(self, request: UpdateRequestV2) -> api_types.UpdateResponseEcu: + """Thread worker for dispatching a local update.""" + self._op_queue.put_nowait(request) + try: + _req_response = self._resp_queue.get(timeout=WAIT_FOR_LOCAL_ECU_ACK_TIMEOUT) + assert isinstance(_req_response, IPCResponse), "unexpected msg" + assert ( + _req_response.session_id == request.session_id + ), "mismatched session_id" - async def _otaclient_control_flags_managing(self): - """Task entry for set/clear otaclient control flags. - - Prevent self ECU from rebooting when their is at least one ECU - under UPDATING ota_status. - """ - while not self._debug_status_checking_shutdown_event.is_set(): - _can_reboot = self._otaclient_control_flags.is_can_reboot_flag_set() - if not self._ecu_status_storage.in_update_child_ecus_id: - if not _can_reboot: - logger.info( - "local otaclient can reboot as no child ECU is in UPDATING ota_status" - ) - self._otaclient_control_flags.set_can_reboot_flag() + if _req_response.res == IPCResEnum.ACCEPT: + return api_types.UpdateResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.NO_FAILURE, + ) else: - if _can_reboot: - logger.info( - f"local otaclient cannot reboot as child ECUs {self._ecu_status_storage.in_update_child_ecus_id}" - " are in UPDATING ota_status" - ) - self._otaclient_control_flags.clear_can_reboot_flag() - await self._polling_waiter() - - # API stub + logger.error( + f"local otaclient doesn't accept upate request: {_req_response.msg}" + ) + return api_types.UpdateResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.RECOVERABLE, + ) + except AssertionError as e: + logger.error(f"local otaclient response with unexpected msg: {e!r}") + return api_types.UpdateResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.RECOVERABLE, + ) + except Exception as e: # failed to get ACK from otaclient within timeout + logger.error(f"local otaclient failed to ACK request: {e!r}") + return api_types.UpdateResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.UNRECOVERABLE, + ) async def update( self, request: api_types.UpdateRequest @@ -147,7 +112,7 @@ async def update( response = api_types.UpdateResponse() # first: dispatch update request to all directly connected subECUs - tasks: Dict[asyncio.Task, ECUContact] = {} + tasks: dict[asyncio.Task, ECUContact] = {} for ecu_contact in self.sub_ecus: if not request.if_contains_ecu(ecu_contact.ecu_id): continue @@ -187,35 +152,21 @@ async def update( # second: dispatch update request to local if required by incoming request if update_req_ecu := request.find_ecu(self.my_ecu_id): - if not self._otaclient_inst.started: - logger.error("otaclient is not running, abort") - response.add_ecu( - api_types.UpdateResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.UNRECOVERABLE, - ) - ) - elif self._otaclient_inst.is_busy: - response.add_ecu( - api_types.UpdateResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.RECOVERABLE, - ) - ) - else: - self._run_in_executor( - self._otaclient_inst.update, - convert_from_apiv2_update_request(update_req_ecu), - ).add_done_callback( - lambda _: logger.info("update execution thread finished") - ) + new_session_id = gen_session_id(update_req_ecu.version) + _resp = await asyncio.get_running_loop().run_in_executor( + self._executor, + self._local_update, + UpdateRequestV2( + version=update_req_ecu.version, + url_base=update_req_ecu.url, + cookies_json=update_req_ecu.cookies, + session_id=new_session_id, + ), + ) + + if _resp.result == api_types.FailureType.NO_FAILURE: update_acked_ecus.add(self.my_ecu_id) - response.add_ecu( - api_types.UpdateResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.NO_FAILURE, - ) - ) + response.add_ecu(_resp) # finally, trigger ecu_status_storage entering active mode if needed if update_acked_ecus: @@ -227,6 +178,47 @@ async def update( ) return response + def _local_rollback( + self, rollback_request: RollbackRequestV2 + ) -> api_types.RollbackResponseEcu: + """Thread worker for dispatching a local rollback.""" + + self._op_queue.put_nowait(rollback_request) + try: + _req_response = self._resp_queue.get(timeout=WAIT_FOR_LOCAL_ECU_ACK_TIMEOUT) + assert isinstance( + _req_response, IPCResponse + ), f"unexpected response: {type(_req_response)}" + assert ( + _req_response.session_id == rollback_request.session_id + ), "mismatched session_id" + + if _req_response.res == IPCResEnum.ACCEPT: + return api_types.RollbackResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.NO_FAILURE, + ) + else: + logger.error( + f"local otaclient doesn't accept upate request: {_req_response.msg}" + ) + return api_types.RollbackResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.RECOVERABLE, + ) + except AssertionError as e: + logger.error(f"local otaclient response with unexpected msg: {e!r}") + return api_types.RollbackResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.RECOVERABLE, + ) + except Exception as e: # failed to get ACK from otaclient within timeout + logger.error(f"local otaclient failed to ACK request: {e!r}") + return api_types.RollbackResponseEcu( + ecu_id=self.my_ecu_id, + result=api_types.FailureType.UNRECOVERABLE, + ) + async def rollback( self, request: api_types.RollbackRequest ) -> api_types.RollbackResponse: @@ -234,7 +226,7 @@ async def rollback( response = api_types.RollbackResponse() # first: dispatch rollback request to all directly connected subECUs - tasks: Dict[asyncio.Task, ECUContact] = {} + tasks: dict[asyncio.Task, ECUContact] = {} for ecu_contact in self.sub_ecus: if not request.if_contains_ecu(ecu_contact.ecu_id): continue @@ -273,31 +265,13 @@ async def rollback( # second: dispatch rollback request to local if required if request.find_ecu(self.my_ecu_id): - if not self._otaclient_inst.started: - logger.error("otaclient is not running, abort") - response.add_ecu( - api_types.RollbackResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.UNRECOVERABLE, - ) - ) - elif self._otaclient_inst.is_busy: - response.add_ecu( - api_types.RollbackResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.RECOVERABLE, - ) - ) - else: - self._run_in_executor(self._otaclient_inst.rollback).add_done_callback( - lambda _: logger.info("rollback execution thread finished") - ) - response.add_ecu( - api_types.RollbackResponseEcu( - ecu_id=self.my_ecu_id, - result=api_types.FailureType.NO_FAILURE, - ) - ) + new_session_id = gen_session_id("__rollback") + _local_resp = await asyncio.get_running_loop().run_in_executor( + self._executor, + self._local_rollback, + RollbackRequestV2(session_id=new_session_id), + ) + response.add_ecu(_local_resp) return response async def status(self, _=None) -> api_types.StatusResponse: diff --git a/src/otaclient/grpc/api_v2/types.py b/src/otaclient/grpc/api_v2/types.py index 9386966ac..208672248 100644 --- a/src/otaclient/grpc/api_v2/types.py +++ b/src/otaclient/grpc/api_v2/types.py @@ -18,7 +18,7 @@ import time -from otaclient._types import OTAClientStatus, OTAStatus, UpdateRequestV2, UpdateTiming +from otaclient._types import OTAClientStatus, OTAStatus, UpdateTiming from otaclient_api.v2 import types as api_types from otaclient_common.proto_wrapper import Duration @@ -101,13 +101,3 @@ def convert_to_apiv2_status(_in: OTAClientStatus) -> api_types.StatusResponseEcu base_res.update_status = update_status return base_res - - -def convert_from_apiv2_update_request( - _in: api_types.UpdateRequestEcu, -) -> UpdateRequestV2: - return UpdateRequestV2( - version=_in.version, - url_base=_in.url, - cookies_json=_in.cookies, - ) diff --git a/src/otaclient/main.py b/src/otaclient/main.py index bc34fdc9c..4d0316993 100644 --- a/src/otaclient/main.py +++ b/src/otaclient/main.py @@ -16,76 +16,74 @@ from __future__ import annotations -import asyncio +import atexit import logging -from concurrent.futures import ThreadPoolExecutor -from queue import Queue +import multiprocessing as mp +import multiprocessing.context as mp_ctx +import multiprocessing.shared_memory as mp_shm +import secrets +import signal +import sys +import threading +import time +from functools import partial from otaclient import __version__ +from otaclient._types import MultipleECUStatusFlags +from otaclient._utils import SharedOTAClientStatusReader, SharedOTAClientStatusWriter logger = logging.getLogger(__name__) +HEALTH_CHECK_INTERAVL = 6 # seconds +# NOTE: the reason to let daemon_process exits after 16 seconds of ota_core dead +# is to allow grpc API server to respond to the status API calls with up-to-date +# failure information from ota_core. +SHUTDOWN_AFTER_CORE_EXIT = 16 # seconds +SHUTDOWN_AFTER_API_SERVER_EXIT = 3 # seconds -async def create_otaclient_grpc_server(): - import grpc.aio +STATUS_SHM_SIZE = 4096 # bytes +MAX_TRACEBACK_SIZE = 2048 # bytes +SHM_HMAC_KEY_LEN = 64 # bytes - from otaclient._status_monitor import OTAClientStatusCollector - from otaclient.configs.cfg import cfg, ecu_info, proxy_info - from otaclient.grpc.api_v2.ecu_status import ECUStatusStorage - from otaclient.grpc.api_v2.ecu_tracker import ECUTracker - from otaclient.grpc.api_v2.servicer import OTAClientAPIServicer - from otaclient.ota_core import OTAClient, OTAClientControlFlags - from otaclient_api.v2 import otaclient_v2_pb2_grpc as v2_grpc - from otaclient_api.v2.api_stub import OtaClientServiceV2 - - _executor = ThreadPoolExecutor(thread_name_prefix="otaclient_main") - _control_flag = OTAClientControlFlags() - - status_report_queue = Queue() - status_collector = OTAClientStatusCollector(status_report_queue) - - ecu_status_storage = ECUStatusStorage() - ecu_tracker = ECUTracker( - ecu_status_storage, - local_status_collector=status_collector, - ) - ecu_tracker.start() +_ota_core_p: mp_ctx.SpawnProcess | None = None +_grpc_server_p: mp_ctx.SpawnProcess | None = None +_shm: mp_shm.SharedMemory | None = None - otaclient_inst = OTAClient( - control_flags=_control_flag, - proxy=proxy_info.get_proxy_for_local_ota(), - status_report_queue=status_report_queue, - ) - status_collector.start() - service_stub = OTAClientAPIServicer( - otaclient_inst, - ecu_status_storage, - control_flag=_control_flag, - executor=_executor, - ) - ota_client_service_v2 = OtaClientServiceV2(service_stub) - server = grpc.aio.server() - v2_grpc.add_OtaClientServiceServicer_to_server( - server=server, servicer=ota_client_service_v2 - ) - server.add_insecure_port(f"{ecu_info.ip_addr}:{cfg.OTA_API_SERVER_PORT}") - return server +def _on_shutdown(sys_exit: bool = False) -> None: # pragma: no cover + global _ota_core_p, _grpc_server_p, _shm + if _ota_core_p: + _ota_core_p.terminate() + _ota_core_p.join() + _ota_core_p = None + + if _grpc_server_p: + _grpc_server_p.terminate() + _grpc_server_p.join() + _grpc_server_p = None + + if _shm: + _shm.close() + _shm.unlink() + _shm = None + if sys_exit: + sys.exit(1) -async def launch_otaclient_grpc_server(): - server = await create_otaclient_grpc_server() - await server.start() - try: - await server.wait_for_termination() - finally: - await server.stop(1) +def _signal_handler(signal_value, _) -> None: # pragma: no cover + print(f"otaclient receives {signal_value=}, shutting down ...") + # NOTE: the daemon_process needs to exit also. + _on_shutdown(sys_exit=True) -def main() -> None: + +def main() -> None: # pragma: no cover from otaclient._logging import configure_logging - from otaclient.configs.cfg import cfg, ecu_info - from otaclient.utils import check_other_otaclient, create_otaclient_rundir + from otaclient._otaproxy_ctx import otaproxy_control_thread + from otaclient._utils import check_other_otaclient, create_otaclient_rundir + from otaclient.configs.cfg import cfg, ecu_info, proxy_info + from otaclient.grpc.api_v2.main import grpc_server_process + from otaclient.ota_core import ota_core_process # configure logging before any code being executed configure_logging() @@ -97,4 +95,88 @@ def main() -> None: check_other_otaclient(cfg.OTACLIENT_PID_FILE) create_otaclient_rundir(cfg.RUN_DIR) - asyncio.run(launch_otaclient_grpc_server()) + # + # ------ start each processes ------ # + # + global _ota_core_p, _grpc_server_p, _shm + + # NOTE: if the atexit hook is triggered by signal received, + # first the signal handler will be executed, and then atexit hook. + # At the time atexit hook is executed, the _ota_core_p, _grpc_server_p + # and _shm are set to None by signal handler. + atexit.register(_on_shutdown) + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + + mp_ctx = mp.get_context("spawn") + _shm = mp_shm.SharedMemory(size=STATUS_SHM_SIZE, create=True) + _key = secrets.token_bytes(SHM_HMAC_KEY_LEN) + + # shared queues and flags + local_otaclient_op_queue = mp_ctx.Queue() + local_otaclient_resp_queue = mp_ctx.Queue() + ecu_status_flags = MultipleECUStatusFlags( + any_in_update=mp_ctx.Event(), + any_requires_network=mp_ctx.Event(), + all_success=mp_ctx.Event(), + ) + + _ota_core_p = mp_ctx.Process( + target=partial( + ota_core_process, + shm_writer_factory=partial( + SharedOTAClientStatusWriter, name=_shm.name, key=_key + ), + ecu_status_flags=ecu_status_flags, + op_queue=local_otaclient_op_queue, + resp_queue=local_otaclient_resp_queue, + max_traceback_size=MAX_TRACEBACK_SIZE, + ), + name="otaclient_ota_core", + ) + _ota_core_p.start() + + _grpc_server_p = mp_ctx.Process( + target=partial( + grpc_server_process, + shm_reader_factory=partial( + SharedOTAClientStatusReader, name=_shm.name, key=_key + ), + op_queue=local_otaclient_op_queue, + resp_queue=local_otaclient_resp_queue, + ecu_status_flags=ecu_status_flags, + ), + name="otaclient_api_server", + ) + _grpc_server_p.start() + + del _key + + # ------ setup main process ------ # + + _otaproxy_control_t = None + if proxy_info.enable_local_ota_proxy: + _otaproxy_control_t = threading.Thread( + target=partial(otaproxy_control_thread, ecu_status_flags), + daemon=True, + name="otaclient_otaproxy_control_t", + ) + _otaproxy_control_t.start() + + while True: + time.sleep(HEALTH_CHECK_INTERAVL) + + if not _ota_core_p.is_alive(): + logger.error( + "ota_core process is dead! " + f"otaclient will exit in {SHUTDOWN_AFTER_CORE_EXIT}seconds ..." + ) + time.sleep(SHUTDOWN_AFTER_CORE_EXIT) + return _on_shutdown() + + if not _grpc_server_p.is_alive(): + logger.error( + f"ota API server is dead, whole otaclient will exit in {SHUTDOWN_AFTER_API_SERVER_EXIT}seconds ..." + ) + time.sleep(SHUTDOWN_AFTER_API_SERVER_EXIT) + return _on_shutdown() diff --git a/src/otaclient/ota_core.py b/src/otaclient/ota_core.py index 83fe51dfd..77696ff64 100644 --- a/src/otaclient/ota_core.py +++ b/src/otaclient/ota_core.py @@ -18,7 +18,9 @@ import errno import json import logging -import os +import multiprocessing.queues as mp_queue +import signal +import sys import threading import time from concurrent.futures import Future @@ -27,8 +29,8 @@ from http import HTTPStatus from json.decoder import JSONDecodeError from pathlib import Path -from queue import Queue -from typing import Any, Iterator, Optional, Type +from queue import Empty, Queue +from typing import Any, Callable, Iterator, NoReturn, Optional, Type from urllib.parse import urlparse import requests.exceptions as requests_exc @@ -43,6 +45,7 @@ ) from otaclient import errors as ota_errors from otaclient._status_monitor import ( + OTAClientStatusCollector, OTAStatusChangeReport, OTAUpdatePhaseChangeReport, SetOTAClientMetaReport, @@ -50,15 +53,25 @@ StatusReport, UpdateProgressReport, ) -from otaclient._types import FailureType, OTAStatus, UpdatePhase, UpdateRequestV2 +from otaclient._types import ( + FailureType, + IPCRequest, + IPCResEnum, + IPCResponse, + MultipleECUStatusFlags, + OTAStatus, + RollbackRequestV2, + UpdatePhase, + UpdateRequestV2, +) +from otaclient._utils import SharedOTAClientStatusWriter, get_traceback, wait_and_log from otaclient.boot_control import BootControllerProtocol, get_boot_controller -from otaclient.configs.cfg import cfg, ecu_info +from otaclient.configs.cfg import cfg, ecu_info, proxy_info from otaclient.create_standby import ( StandbySlotCreatorProtocol, get_standby_slot_creator, ) from otaclient.create_standby.common import DeltaBundle -from otaclient.utils import get_traceback, wait_and_log from otaclient_common.common import ensure_otaproxy_start from otaclient_common.downloader import ( EMPTY_FILE_SHA256, @@ -79,32 +92,12 @@ DOWNLOAD_STATS_REPORT_BATCH = 300 DOWNLOAD_REPORT_INTERVAL = 1 # second +OP_CHECK_INTERVAL = 1 # second +HOLD_REQ_HANDLING_ON_ACK_REQUEST = 16 # seconds +WAIT_FOR_OTAPROXY_ONLINE = 3 * 60 # 3mins -class OTAClientError(Exception): ... - - -class OTAClientControlFlags: - """ - When self ECU's otaproxy is enabled, all the child ECUs of this ECU - and self ECU OTA update will depend on its otaproxy, we need to - control when otaclient can start its downloading/reboot with considering - whether local otaproxy is started/required. - """ - - def __init__(self) -> None: - self._can_reboot = threading.Event() - - def is_can_reboot_flag_set(self) -> bool: - return self._can_reboot.is_set() - - def wait_can_reboot_flag(self): - self._can_reboot.wait() - def set_can_reboot_flag(self): - self._can_reboot.set() - - def clear_can_reboot_flag(self): - self._can_reboot.clear() +class OTAClientError(Exception): ... def _download_exception_handler(_fut: Future[Any]) -> bool: @@ -172,14 +165,34 @@ def __init__( upper_otaproxy: str | None = None, boot_controller: BootControllerProtocol, create_standby_cls: Type[StandbySlotCreatorProtocol], - control_flags: OTAClientControlFlags, + ecu_status_flags: MultipleECUStatusFlags, status_report_queue: Queue[StatusReport], session_id: str, ) -> None: + self.update_version = version + self.update_start_timestamp = int(time.time()) self.ca_chains_store = ca_chains_store self.session_id = session_id self._status_report_queue = status_report_queue + status_report_queue.put_nowait( + StatusReport( + payload=OTAUpdatePhaseChangeReport( + new_update_phase=UpdatePhase.INITIALIZING, + trigger_timestamp=self.update_start_timestamp, + ), + session_id=session_id, + ) + ) + status_report_queue.put_nowait( + StatusReport( + payload=SetUpdateMetaReport( + update_firmware_version=version, + ), + session_id=session_id, + ) + ) + # ------ define OTA temp paths ------ # self._ota_tmp_on_standby = Path(cfg.STANDBY_SLOT_MNT) / Path( cfg.OTA_TMP_STORE @@ -202,47 +215,13 @@ def __init__( # ------ parse upper proxy ------ # logger.debug("configure proxy setting...") - proxies = {} - if upper_otaproxy: - logger.info( - f"use {upper_otaproxy} for local OTA update, " - f"wait for otaproxy@{upper_otaproxy} online..." - ) - ensure_otaproxy_start( - upper_otaproxy, - probing_timeout=cfg.DOWNLOAD_INACTIVE_TIMEOUT, - ) - # NOTE(20221013): check requests document for how to set proxy, - # we only support using http proxy here. - proxies["http"] = upper_otaproxy + self._upper_proxy = upper_otaproxy # ------ init updater implementation ------ # - self._control_flags = control_flags + self.ecu_status_flags = ecu_status_flags self._boot_controller = boot_controller self._create_standby_cls = create_standby_cls - # ------ init update status ------ # - self.update_version = version - self.update_start_timestamp = int(time.time()) - - status_report_queue.put_nowait( - StatusReport( - payload=OTAUpdatePhaseChangeReport( - new_update_phase=UpdatePhase.INITIALIZING, - trigger_timestamp=self.update_start_timestamp, - ), - session_id=self.session_id, - ) - ) - status_report_queue.put_nowait( - StatusReport( - payload=SetUpdateMetaReport( - update_firmware_version=version, - ), - session_id=self.session_id, - ) - ) - # ------ init variables needed for update ------ # _url_base = urlparse(raw_url_base) _path = f"{_url_base.path.rstrip('/')}/" @@ -254,7 +233,9 @@ def __init__( hash_func=sha256, chunk_size=cfg.CHUNK_SIZE, cookies=cookies, - proxies=proxies, + # NOTE(20221013): check requests document for how to set proxy, + # we only support using http proxy here. + proxies={"http": upper_otaproxy} if upper_otaproxy else None, ) self._downloader_mapper: dict[int, Downloader] = {} @@ -427,6 +408,18 @@ def _execute_update(self): """Implementation of OTA updating.""" logger.info(f"execute local update({ecu_info.ecu_id=}): {self.update_version=}") + if _upper_proxy := self._upper_proxy: + logger.info( + f"use {_upper_proxy} for local OTA update, " + f"wait for otaproxy@{_upper_proxy} online..." + ) + + # NOTE: will raise a built-in ConnnectionError at timeout + ensure_otaproxy_start( + _upper_proxy, + probing_timeout=WAIT_FOR_OTAPROXY_ONLINE, + ) + # ------ init, processing metadata ------ # logger.debug("process metadata.jwt...") self._status_report_queue.put_nowait( @@ -536,8 +529,10 @@ def _execute_update(self): try: self._download_files(otameta, delta_bundle.get_download_list()) except TasksEnsureFailed: - # NOTE: the only cause of a TaskEnsureFailed being raised is the download_watchdog timeout. - _err_msg = f"download stalls longer than {cfg.DOWNLOAD_INACTIVE_TIMEOUT}, abort OTA" + _err_msg = ( + "download aborted due to download stalls longer than " + f"{cfg.DOWNLOAD_INACTIVE_TIMEOUT}, or otaclient process is terminated, abort OTA" + ) logger.error(_err_msg) raise ota_errors.NetworkError(_err_msg, module=__name__) from None finally: @@ -583,11 +578,16 @@ def _execute_update(self): session_id=self.session_id, ) ) - wait_and_log( - flag=self._control_flags._can_reboot, - message="permit reboot flag", - log_func=logger.info, - ) + + # NOTE: we don't need to wait for sub ECUs if sub ECUs don't + # depend on otaproxy on this ECU. + if proxy_info.enable_local_ota_proxy: + wait_and_log( + check_flag=self.ecu_status_flags.any_requires_network.is_set, + check_for=False, + message="permit reboot flag", + log_func=logger.info, + ) logger.info(f"device will reboot in {WAIT_BEFORE_REBOOT} seconds!") time.sleep(WAIT_BEFORE_REBOOT) @@ -628,25 +628,17 @@ def execute(self): class OTAClient: - """ - Init params: - boot_controller: boot control instance - create_standby_cls: type of create standby slot mechanism to use - my_ecu_id: ECU id of the device running this otaclient instance - control_flags: flags used by otaclient and ota_service stub for synchronization - proxy: upper otaproxy URL - """ def __init__( self, *, - control_flags: OTAClientControlFlags, + ecu_status_flags: MultipleECUStatusFlags, proxy: Optional[str] = None, status_report_queue: Queue[StatusReport], ) -> None: self.my_ecu_id = ecu_info.ecu_id self.proxy = proxy - self.control_flags = control_flags + self.ecu_status_flags = ecu_status_flags self._status_report_queue = status_report_queue self._live_ota_status = OTAStatus.INITIALIZED @@ -709,17 +701,6 @@ def __init__( self.started = True logger.info("otaclient started") - def _gen_session_id(self, update_version: str = "") -> str: - """Generate a unique session_id for the new OTA session. - - token schema: - --<4bytes_hex> - """ - _time_factor = str(int(time.time())) - _random_factor = os.urandom(4).hex() - - return f"{update_version}-{_time_factor}-{_random_factor}" - def _on_failure( self, exc: Exception, @@ -762,10 +743,8 @@ def update(self, request: UpdateRequestV2) -> None: NOTE that update API will not raise any exceptions. The failure information is available via status API. """ - if self.is_busy: - return - - new_session_id = self._gen_session_id(request.version) + self._live_ota_status = OTAStatus.UPDATING + new_session_id = request.session_id self._status_report_queue.put_nowait( StatusReport( payload=OTAStatusChangeReport( @@ -784,7 +763,6 @@ def update(self, request: UpdateRequestV2) -> None: module=__name__, ) - self._live_ota_status = OTAStatus.UPDATING _OTAUpdater( version=request.version, raw_url_base=request.url_base, @@ -792,7 +770,7 @@ def update(self, request: UpdateRequestV2) -> None: ca_chains_store=self.ca_chains_store, boot_controller=self.boot_controller, create_standby_cls=self.create_standby_cls, - control_flags=self.control_flags, + ecu_status_flags=self.ecu_status_flags, upper_otaproxy=self.proxy, status_report_queue=self._status_report_queue, session_id=new_session_id, @@ -806,11 +784,9 @@ def update(self, request: UpdateRequestV2) -> None: failure_type=e.failure_type, ) - def rollback(self) -> None: - if self.is_busy: - return - - new_session_id = self._gen_session_id("___rollback") + def rollback(self, request: RollbackRequestV2) -> None: + self._live_ota_status = OTAStatus.ROLLBACKING + new_session_id = request.session_id self._status_report_queue.put_nowait( StatusReport( payload=OTAStatusChangeReport( @@ -819,17 +795,133 @@ def rollback(self) -> None: session_id=new_session_id, ) ) - logger.info(f"start new OTA rollback session: {new_session_id=}") + logger.info(f"start new OTA rollback session: {new_session_id=}") try: logger.info("[rollback] entering...") self._live_ota_status = OTAStatus.ROLLBACKING _OTARollbacker(boot_controller=self.boot_controller).execute() except ota_errors.OTAError as e: - self._live_ota_status = OTAStatus.FAILURE self._on_failure( e, ota_status=OTAStatus.FAILURE, failure_reason=e.get_failure_reason(), failure_type=e.failure_type, ) + + def main( + self, + *, + req_queue: mp_queue.Queue[IPCRequest], + resp_queue: mp_queue.Queue[IPCResponse], + ) -> NoReturn: + """Main loop of ota_core process.""" + _allow_request_after = 0 + while True: + _now = int(time.time()) + try: + request = req_queue.get(timeout=OP_CHECK_INTERVAL) + except Empty: + continue + + if _now < _allow_request_after or self.is_busy: + _err_msg = ( + f"otaclient is busy at {self._live_ota_status} or " + f"request too quickly({_allow_request_after=}), " + f"reject {request}" + ) + logger.warning(_err_msg) + resp_queue.put_nowait( + IPCResponse( + res=IPCResEnum.REJECT_BUSY, + msg=_err_msg, + session_id=request.session_id, + ) + ) + + elif isinstance(request, UpdateRequestV2): + + _update_thread = threading.Thread( + target=self.update, + args=[request], + daemon=True, + name="ota_update_executor", + ) + _update_thread.start() + + resp_queue.put_nowait( + IPCResponse( + res=IPCResEnum.ACCEPT, + session_id=request.session_id, + ) + ) + _allow_request_after = _now + HOLD_REQ_HANDLING_ON_ACK_REQUEST + + elif ( + isinstance(request, RollbackRequestV2) + and self._live_ota_status == OTAStatus.SUCCESS + ): + _rollback_thread = threading.Thread( + target=self.rollback, + args=[request], + daemon=True, + name="ota_rollback_executor", + ) + _rollback_thread.start() + + resp_queue.put_nowait( + IPCResponse( + res=IPCResEnum.ACCEPT, + session_id=request.session_id, + ) + ) + _allow_request_after = _now + HOLD_REQ_HANDLING_ON_ACK_REQUEST + else: + + _err_msg = f"request is invalid: {request=}, {self._live_ota_status=}" + logger.error(_err_msg) + resp_queue.put_nowait( + IPCResponse( + res=IPCResEnum.REJECT_OTHER, + msg=_err_msg, + session_id=request.session_id, + ) + ) + + +def _sign_handler(signal_value, frame) -> NoReturn: + print(f"ota_core process receives {signal_value=}, exits ...") + sys.exit(1) + + +def ota_core_process( + *, + shm_writer_factory: Callable[[], SharedOTAClientStatusWriter], + ecu_status_flags: MultipleECUStatusFlags, + op_queue: mp_queue.Queue[IPCRequest], + resp_queue: mp_queue.Queue[IPCResponse], + max_traceback_size: int, # in bytes +): + from otaclient._logging import configure_logging + from otaclient.configs.cfg import proxy_info + from otaclient.ota_core import OTAClient + + signal.signal(signal.SIGTERM, _sign_handler) + configure_logging() + + shm_writer = shm_writer_factory() + + _local_status_report_queue = Queue() + _status_monitor = OTAClientStatusCollector( + msg_queue=_local_status_report_queue, + shm_status=shm_writer, + max_traceback_size=max_traceback_size, + ) + _status_monitor.start() + + _ota_core = OTAClient( + ecu_status_flags=ecu_status_flags, + proxy=proxy_info.get_proxy_for_local_ota(), + status_report_queue=_local_status_report_queue, + ) + _ota_core.main(req_queue=op_queue, resp_queue=resp_queue) diff --git a/tests/conftest.py b/tests/conftest.py index 4697c4960..1df11b532 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -262,12 +262,21 @@ def proxy_info_fixture(tmp_path: Path) -> ProxyInfo: return parse_proxy_info(_yaml_f) +MAX_TRACEBACK_SIZE = 2048 + + @pytest.fixture(scope="class") -def ota_status_collector() -> ( - Generator[tuple[OTAClientStatusCollector, Queue[StatusReport]], Any, None] -): +def ota_status_collector( + class_mocker: pytest_mock.MockerFixture, +) -> Generator[tuple[OTAClientStatusCollector, Queue[StatusReport]], Any, None]: + _shm_mock = class_mocker.MagicMock() + _report_queue: Queue[StatusReport] = Queue() - _status_collector = OTAClientStatusCollector(_report_queue) + _status_collector = OTAClientStatusCollector( + msg_queue=_report_queue, + shm_status=_shm_mock, + max_traceback_size=MAX_TRACEBACK_SIZE, + ) _collector_thread = _status_collector.start() try: diff --git a/tests/test_otaclient/test_create_standby.py b/tests/test_otaclient/test_create_standby.py index d2c7e1f2e..b0abbbec8 100644 --- a/tests/test_otaclient/test_create_standby.py +++ b/tests/test_otaclient/test_create_standby.py @@ -37,7 +37,7 @@ from otaclient.configs.cfg import cfg as otaclient_cfg from otaclient.create_standby import common, rebuild_mode from otaclient.create_standby.rebuild_mode import RebuildMode -from otaclient.ota_core import OTAClientControlFlags, _OTAUpdater +from otaclient.ota_core import _OTAUpdater from tests.conftest import TestConfiguration as cfg from tests.utils import SlotMeta, compare_dir @@ -105,16 +105,24 @@ def test_update_with_rebuild_mode( mocker: MockerFixture, ): status_collector, status_report_queue = ota_status_collector - - # ------ execution ------ # - otaclient_control_flags = typing.cast( - OTAClientControlFlags, mocker.MagicMock(spec=OTAClientControlFlags) + ecu_status_flags = mocker.MagicMock() + ecu_status_flags.any_requires_network.is_set = mocker.MagicMock( + return_value=False ) - otaclient_control_flags._can_reboot = _can_reboot = mocker.MagicMock() - _can_reboot.is_set = mocker.MagicMock(return_value=True) + # ------ execution ------ # ca_store = load_ca_cert_chains(cfg.CERTS_DIR) + # update OTA status to update and assign session_id before OTAUpdate initialized + status_report_queue.put_nowait( + StatusReport( + payload=OTAStatusChangeReport( + new_ota_status=OTAStatus.UPDATING, + ), + session_id=self.SESSION_ID, + ) + ) + _updater = _OTAUpdater( version=cfg.UPDATE_VERSION, raw_url_base=cfg.OTA_IMAGE_URL, @@ -122,29 +130,22 @@ def test_update_with_rebuild_mode( ca_chains_store=ca_store, upper_otaproxy=None, boot_controller=self._boot_control, + ecu_status_flags=ecu_status_flags, create_standby_cls=RebuildMode, - control_flags=otaclient_control_flags, status_report_queue=status_report_queue, session_id=self.SESSION_ID, ) _updater._process_persistents = persist_handler = mocker.MagicMock() - # update OTA status to update and assign session_id before execution - status_report_queue.put_nowait( - StatusReport( - payload=OTAStatusChangeReport( - new_ota_status=OTAStatus.UPDATING, - ), - session_id=self.SESSION_ID, - ) - ) + time.sleep(2) + + # ------ execution ------ # _updater.execute() - time.sleep(2) # wait for downloader to record stats # ------ assertions ------ # persist_handler.assert_called_once() - otaclient_control_flags._can_reboot.is_set.assert_called_once() + ecu_status_flags.any_requires_network.is_set.assert_called_once() # --- ensure the update stats are collected collected_status = status_collector.otaclient_status assert collected_status diff --git a/tests/test_otaclient/test_grpc/test_api_v2/test_ecu_status.py b/tests/test_otaclient/test_grpc/test_api_v2/test_ecu_status.py index 1da659250..8e3081f8d 100644 --- a/tests/test_otaclient/test_grpc/test_api_v2/test_ecu_status.py +++ b/tests/test_otaclient/test_grpc/test_api_v2/test_ecu_status.py @@ -17,6 +17,7 @@ import asyncio import logging +import threading from typing import Any import pytest @@ -25,6 +26,7 @@ from otaclient import __version__ from otaclient import _types as _internal_types +from otaclient._types import MultipleECUStatusFlags from otaclient.configs import DefaultOTAClientConfigs from otaclient.configs._ecu_info import ECUInfo from otaclient.grpc.api_v2.servicer import ECUStatusStorage @@ -49,7 +51,13 @@ async def setup_test(self, mocker: MockerFixture, ecu_info_fixture: ECUInfo): mocker.patch(f"{ECU_STATUS_MODULE}.ecu_info", ecu_info) # init and setup the ecu_storage - self.ecu_storage = ECUStatusStorage() + # NOTE: here we use threading.Event instead + self.ecu_status_flags = ecu_status_flags = MultipleECUStatusFlags( + any_in_update=threading.Event(), # type: ignore[assignment] + any_requires_network=threading.Event(), # type: ignore[assignment] + all_success=threading.Event(), # type: ignore[assignment] + ) + self.ecu_storage = ECUStatusStorage(ecu_status_flags=ecu_status_flags) _mocked_otaclient_cfg = DefaultOTAClientConfigs() # NOTE: decrease the interval for faster testing @@ -326,7 +334,7 @@ async def test_export( compare_message(exported, expected) @pytest.mark.parametrize( - "local_ecu_status,sub_ecus_status,properties_dict", + "local_ecu_status,sub_ecus_status,properties_dict,flags_status", ( # case 1: ( @@ -368,8 +376,12 @@ async def test_export( "in_update_ecus_id": {"autoware", "p2"}, "in_update_child_ecus_id": {"p2"}, "failed_ecus_id": {"p1"}, - "any_requires_network": True, "success_ecus_id": set(), + }, + # ecu_status_flags + { + "any_in_update": True, + "any_requires_network": True, "all_success": False, }, ), @@ -413,8 +425,12 @@ async def test_export( "in_update_ecus_id": {"p2"}, "in_update_child_ecus_id": {"p2"}, "failed_ecus_id": {"p1"}, - "any_requires_network": True, "success_ecus_id": {"autoware"}, + }, + # ecu_status_flags + { + "any_in_update": True, + "any_requires_network": True, "all_success": False, }, ), @@ -425,6 +441,7 @@ async def test_overall_ecu_status_report_generation( local_ecu_status: _internal_types.OTAClientStatus, sub_ecus_status: list[api_types.StatusResponse], properties_dict: dict[str, Any], + flags_status: dict[str, bool], ): # --- prepare --- # await self.ecu_storage.update_from_local_ecu(local_ecu_status) @@ -438,8 +455,11 @@ async def test_overall_ecu_status_report_generation( for k, v in properties_dict.items(): assert getattr(self.ecu_storage, k) == v, f"status_report attr {k} mismatch" + for k, v in flags_status.items(): + assert getattr(self.ecu_status_flags, k).is_set() == v + @pytest.mark.parametrize( - "local_ecu_status,sub_ecus_status,ecus_accept_update_request,properties_dict", + "local_ecu_status,sub_ecus_status,ecus_accept_update_request,properties_dict,flags_status", ( # case 1: # There is FAILED/UPDATING ECUs existed in the cluster. @@ -486,8 +506,12 @@ async def test_overall_ecu_status_report_generation( "in_update_ecus_id": {"autoware", "p2"}, "in_update_child_ecus_id": {"p2"}, "failed_ecus_id": {"p1"}, - "any_requires_network": True, "success_ecus_id": set(), + }, + # ecu_status_flags + { + "any_in_update": True, + "any_requires_network": True, "all_success": False, }, ), @@ -534,8 +558,12 @@ async def test_overall_ecu_status_report_generation( "in_update_ecus_id": {"autoware", "p1"}, "in_update_child_ecus_id": {"p1"}, "failed_ecus_id": set(), - "any_requires_network": True, "success_ecus_id": {"p2"}, + }, + # ecu_status_flags + { + "any_in_update": True, + "any_requires_network": True, "all_success": False, }, ), @@ -547,6 +575,7 @@ async def test_on_receive_update_request( sub_ecus_status: list[api_types.StatusResponse], ecus_accept_update_request: list[str], properties_dict: dict[str, Any], + flags_status: dict[str, bool], mocker: pytest_mock.MockerFixture, ): # --- prepare --- # @@ -571,7 +600,9 @@ async def test_on_receive_update_request( # --- assertion --- # for k, v in properties_dict.items(): assert getattr(self.ecu_storage, k) == v, f"status_report attr {k} mismatch" - assert self.ecu_storage.active_ota_update_present.is_set() + + for k, v in flags_status.items(): + assert getattr(self.ecu_status_flags, k).is_set() == v async def test_polling_waiter_switching_from_idling_to_active(self): """Waiter should immediately return if active_ota_update_present is set.""" @@ -579,9 +610,9 @@ async def test_polling_waiter_switching_from_idling_to_active(self): async def _event_setter(): await asyncio.sleep(_sleep_time) - self.ecu_storage.active_ota_update_present.set() + self.ecu_status_flags.any_in_update.set() - self.ecu_storage.active_ota_update_present.clear() + self.ecu_status_flags.any_in_update.clear() _waiter = self.ecu_storage.get_polling_waiter() asyncio.create_task(_event_setter()) # waiter should return on active_ota_update_present is set, instead of waiting the diff --git a/tests/test_otaclient/test_grpc/test_api_v2/test_servicer.py b/tests/test_otaclient/test_grpc/test_api_v2/test_servicer.py deleted file mode 100644 index c9a96cb52..000000000 --- a/tests/test_otaclient/test_grpc/test_api_v2/test_servicer.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Set - -import pytest -from pytest_mock import MockerFixture - -from otaclient.configs._ecu_info import ECUInfo -from otaclient.configs._proxy_info import ProxyInfo -from otaclient.grpc.api_v2 import ecu_status, servicer -from otaclient.grpc.api_v2.ecu_tracker import ECUTracker -from otaclient.grpc.api_v2.servicer import ( - ECUStatusStorage, - OTAClientAPIServicer, - OTAProxyLauncher, -) -from otaclient.grpc.api_v2.types import convert_from_apiv2_update_request -from otaclient.ota_core import OTAClient, OTAClientControlFlags -from otaclient_api.v2 import types as api_types -from otaclient_api.v2.api_caller import OTAClientCall -from tests.utils import compare_message - -logger = logging.getLogger(__name__) - -SERVICER_MODULE = servicer.__name__ -ECU_STATUS_MODULE = ecu_status.__name__ - - -class TestOTAClientServiceStub: - POLLING_INTERVAL = 1 - ENSURE_NEXT_CHECKING_ROUND = 1.2 - - @staticmethod - async def _subecu_accept_update_request( - ecu_id, *args, **kwargs - ) -> api_types.UpdateResponse: - return api_types.UpdateResponse( - ecu=[ - api_types.UpdateResponseEcu( - ecu_id=ecu_id, result=api_types.FailureType.NO_FAILURE - ) - ] - ) - - @pytest.fixture(autouse=True) - async def setup_test( - self, - mocker: MockerFixture, - ecu_info_fixture: ECUInfo, - proxy_info_fixture: ProxyInfo, - ): - threadpool = ThreadPoolExecutor() - - # ------ mock and patch ------ # - self.ecu_info = ecu_info = ecu_info_fixture - mocker.patch(f"{SERVICER_MODULE}.ecu_info", ecu_info) - - # NOTE: decrease the interval to speed up testing - # (used by _otaproxy_lifecycle_managing/_otaclient_control_flags_managing task) - mocker.patch( - f"{ECU_STATUS_MODULE}.ACTIVE_POLLING_INTERVAL", self.POLLING_INTERVAL - ) - mocker.patch( - f"{ECU_STATUS_MODULE}.IDLE_POLLING_INTERVAL", self.POLLING_INTERVAL - ) - - # ------ init and setup the ecu_storage ------ # - self.control_flag = OTAClientControlFlags() - self.ecu_storage = ECUStatusStorage() - self.ecu_storage.on_ecus_accept_update_request = mocker.AsyncMock() - # NOTE: disable internal overall ecu status generation task as we - # will manipulate the values by ourselves. - self.ecu_storage._debug_properties_update_shutdown_event.set() - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) # ensure the task stopping - - # --- mocker --- # - self.otaclient_inst = mocker.MagicMock(spec=OTAClient) - type(self.otaclient_inst).started = mocker.PropertyMock(return_value=True) - type(self.otaclient_inst).is_busy = mocker.PropertyMock(return_value=False) - - self.ecu_status_tracker = mocker.MagicMock(spec=ECUTracker) - self.otaproxy_launcher = mocker.MagicMock(spec=OTAProxyLauncher) - # mock OTAClientCall, make update_call return success on any update dispatches to subECUs - self.otaclient_call = mocker.AsyncMock(spec=OTAClientCall) - self.otaclient_call.update_call = mocker.AsyncMock( - wraps=self._subecu_accept_update_request - ) - - # ------ mock and patch proxy_info ------ # - self.proxy_info = proxy_info = proxy_info_fixture - mocker.patch(f"{SERVICER_MODULE}.proxy_info", proxy_info) - - # --- patching and mocking --- # - mocker.patch( - f"{SERVICER_MODULE}.ECUStatusStorage", - mocker.MagicMock(return_value=self.ecu_storage), - ) - mocker.patch( - f"{SERVICER_MODULE}.OTAProxyLauncher", - mocker.MagicMock(return_value=self.otaproxy_launcher), - ) - mocker.patch(f"{SERVICER_MODULE}.OTAClientCall", self.otaclient_call) - - # --- start the OTAClientServiceStub --- # - self.otaclient_service_stub = OTAClientAPIServicer( - otaclient_inst=self.otaclient_inst, - ecu_status_storage=self.ecu_storage, - control_flag=self.control_flag, - executor=threadpool, - ) - - try: - yield - finally: - self.otaclient_service_stub._debug_status_checking_shutdown_event.set() - threadpool.shutdown() - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) # ensure shutdown - - async def test__otaproxy_lifecycle_managing(self): - """ - otaproxy startup/shutdown is only controlled by any_requires_network - in overall ECU status report. - """ - # ------ otaproxy startup ------- # - # --- prepartion --- # - self.otaproxy_launcher.is_running = False - self.ecu_storage.any_requires_network = True - - # --- wait for execution --- # - # wait for _otaproxy_lifecycle_managing to launch - # the otaproxy on overall ecu status changed - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) - - # --- assertion --- # - self.otaproxy_launcher.start.assert_called_once() - - # ------ otaproxy shutdown ------ # - # --- prepartion --- # - # set the OTAPROXY_SHUTDOWN_DELAY to allow start/stop in single test - self.otaclient_service_stub.OTAPROXY_SHUTDOWN_DELAY = 1 # type: ignore - self.otaproxy_launcher.is_running = True - self.ecu_storage.any_requires_network = False - - # --- wait for execution --- # - # wait for _otaproxy_lifecycle_managing to shutdown - # the otaproxy on overall ecu status changed - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) - - # --- assertion --- # - self.otaproxy_launcher.stop.assert_called_once() - - # ---- cache dir cleanup --- # - # only cleanup cache dir on all ECUs in SUCCESS ota_status - self.ecu_storage.any_requires_network = False - self.ecu_storage.all_success = True - self.otaproxy_launcher.is_running = False - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) - - # --- assertion --- # - self.otaproxy_launcher.cleanup_cache_dir.assert_called_once() - - async def test__otaclient_control_flags_managing(self): - otaclient_control_flags = self.control_flag - # there are child ECUs in UPDATING - self.ecu_storage.in_update_child_ecus_id = {"p1", "p2"} - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) - assert not otaclient_control_flags._can_reboot.is_set() - - # no more child ECUs in UPDATING - self.ecu_storage.in_update_child_ecus_id = set() - await asyncio.sleep(self.ENSURE_NEXT_CHECKING_ROUND) - assert otaclient_control_flags._can_reboot.is_set() - - @pytest.mark.parametrize( - "update_request, update_target_ids, expected", - ( - # update request for autoware, p1 ecus - ( - api_types.UpdateRequest( - ecu=[ - api_types.UpdateRequestEcu( - ecu_id="autoware", - version="789.x", - url="url", - cookies="cookies", - ), - api_types.UpdateRequestEcu( - ecu_id="p1", - version="789.x", - url="url", - cookies="cookies", - ), - ] - ), - {"autoware", "p1"}, - # NOTE: order matters! - # update request dispatching to subECUs happens first, - # and then to the local ECU. - api_types.UpdateResponse( - ecu=[ - api_types.UpdateResponseEcu( - ecu_id="p1", - result=api_types.FailureType.NO_FAILURE, - ), - api_types.UpdateResponseEcu( - ecu_id="autoware", - result=api_types.FailureType.NO_FAILURE, - ), - ] - ), - ), - # update only p2 - ( - api_types.UpdateRequest( - ecu=[ - api_types.UpdateRequestEcu( - ecu_id="p2", - version="789.x", - url="url", - cookies="cookies", - ), - ] - ), - {"p2"}, - api_types.UpdateResponse( - ecu=[ - api_types.UpdateResponseEcu( - ecu_id="p2", - result=api_types.FailureType.NO_FAILURE, - ), - ] - ), - ), - ), - ) - async def test_update_normal( - self, - update_request: api_types.UpdateRequest, - update_target_ids: Set[str], - expected: api_types.UpdateResponse, - ): - # --- execution --- # - resp = await self.otaclient_service_stub.update(update_request) - - # --- assertion --- # - compare_message(resp, expected) - - self.otaclient_call.update_call.assert_called() - self.ecu_storage.on_ecus_accept_update_request.assert_called_once_with( # type: ignore - update_target_ids - ) - # assert otaclient_inst receives the update request if we have update request for self ECU - if update_request.if_contains_ecu("autoware"): - _update_request_ecu = update_request.find_ecu("autoware") - assert _update_request_ecu - - self.otaclient_inst.update.assert_called_once_with( - convert_from_apiv2_update_request(_update_request_ecu) - ) - - async def test_update_local_ecu_busy( - self, - mocker: MockerFixture, - ): - # --- preparation --- # - is_busy_mock = mocker.PropertyMock(return_value=True) # is busy - type(self.otaclient_inst).is_busy = is_busy_mock - - update_request_ecu = api_types.UpdateRequestEcu( - ecu_id="autoware", version="version", url="url", cookies="cookies" - ) - - # --- execution --- # - await self.otaclient_service_stub.update( - api_types.UpdateRequest(ecu=[update_request_ecu]) - ) - - # --- assertion --- # - # assert otaclient_inst doesn't receive the update request - self.otaclient_inst.update.assert_not_called() diff --git a/tests/test_otaclient/test_main.py b/tests/test_otaclient/test_main.py deleted file mode 100644 index 14b96cbc0..000000000 --- a/tests/test_otaclient/test_main.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import time -from multiprocessing import Process -from pathlib import Path - -import pytest -from pytest import LogCaptureFixture -from pytest_mock import MockerFixture - -from otaclient.configs.cfg import cfg as otaclient_cfg - -FIRST_LINE_LOG = "d3b6bdb | 2021-10-27 09:36:48 +0900 | Initial commit" -MAIN_MODULE = "otaclient.main" -UTILS_MODULE = "otaclient.utils" - - -class TestMain: - @pytest.fixture(autouse=True) - def patch_main(self, mocker: MockerFixture, tmp_path: Path): - mocker.patch(f"{MAIN_MODULE}.launch_otaclient_grpc_server") - mocker.patch("otaclient._logging.configure_logging") - - self._sys_exit_mocker = mocker.MagicMock(side_effect=ValueError) - mocker.patch(f"{UTILS_MODULE}.sys.exit", self._sys_exit_mocker) - - @pytest.fixture - def background_process(self): - def _waiting(): - time.sleep(1234) - - _p = Process(target=_waiting) - try: - _p.start() - Path(otaclient_cfg.OTACLIENT_PID_FILE).write_text(f"{_p.pid}") - yield _p.pid - finally: - _p.kill() - - def test_main(self, caplog: LogCaptureFixture): - from otaclient.main import main - - main() - assert caplog.records[0].msg == "started" - assert Path(otaclient_cfg.OTACLIENT_PID_FILE).read_text() == f"{os.getpid()}" - - def test_with_other_otaclient_started(self, background_process): - from otaclient.main import main - - _other_pid = f"{background_process}" - with pytest.raises(ValueError): - main() - self._sys_exit_mocker.assert_called_once() - assert Path(otaclient_cfg.OTACLIENT_PID_FILE).read_text() == _other_pid diff --git a/tests/test_otaclient/test_ota_core.py b/tests/test_otaclient/test_ota_core.py index 0c283a5ad..ffaab77b0 100644 --- a/tests/test_otaclient/test_ota_core.py +++ b/tests/test_otaclient/test_ota_core.py @@ -38,7 +38,7 @@ from otaclient.create_standby import StandbySlotCreatorProtocol from otaclient.create_standby.common import DeltaBundle, RegularDelta from otaclient.errors import OTAErrorRecoverable -from otaclient.ota_core import OTAClient, OTAClientControlFlags, _OTAUpdater +from otaclient.ota_core import OTAClient, _OTAUpdater from tests.conftest import TestConfiguration as cfg from tests.utils import SlotMeta @@ -158,18 +158,26 @@ def test_otaupdater( self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]], mocker: pytest_mock.MockerFixture, - ): - from otaclient.ota_core import OTAClientControlFlags, _OTAUpdater - + ) -> None: _, report_queue = ota_status_collector + ecu_status_flags = mocker.MagicMock() + ecu_status_flags.any_requires_network.is_set = mocker.MagicMock( + return_value=False + ) # ------ execution ------ # - otaclient_control_flags = mocker.MagicMock(spec=OTAClientControlFlags) - otaclient_control_flags._can_reboot = _can_reboot = mocker.MagicMock() - _can_reboot.is_set = mocker.MagicMock(return_value=True) - ca_store = load_ca_cert_chains(cfg.CERTS_DIR) + # update OTA status to update and assign session_id before execution + report_queue.put_nowait( + StatusReport( + payload=OTAStatusChangeReport( + new_ota_status=OTAStatus.UPDATING, + ), + session_id=self.SESSION_ID, + ) + ) + _updater = _OTAUpdater( version=cfg.UPDATE_VERSION, raw_url_base=cfg.OTA_IMAGE_URL, @@ -178,21 +186,12 @@ def test_otaupdater( boot_controller=self._boot_control, upper_otaproxy=None, create_standby_cls=self._create_standby_cls, - control_flags=otaclient_control_flags, + ecu_status_flags=ecu_status_flags, session_id=self.SESSION_ID, status_report_queue=report_queue, ) _updater._process_persistents = process_persists_handler = mocker.MagicMock() - # update OTA status to update and assign session_id before execution - report_queue.put_nowait( - StatusReport( - payload=OTAStatusChangeReport( - new_ota_status=OTAStatus.UPDATING, - ), - session_id=self.SESSION_ID, - ) - ) _updater.execute() # ------ assertions ------ # @@ -203,7 +202,7 @@ def test_otaupdater( assert _downloaded_files_size == self._delta_bundle.total_download_files_size # assert the control_flags has been waited - otaclient_control_flags._can_reboot.is_set.assert_called_once() + ecu_status_flags.any_requires_network.is_set.assert_called_once() assert _updater.update_version == str(cfg.UPDATE_VERSION) @@ -235,16 +234,22 @@ def mock_setup( mocker: pytest_mock.MockerFixture, ): _, status_report_queue = ota_status_collector + ecu_status_flags = mocker.MagicMock() + ecu_status_flags.any_requires_network.is_set = mocker.MagicMock( + return_value=False + ) # --- mock setup --- # - self.control_flags = mocker.MagicMock(spec=OTAClientControlFlags) + self.control_flags = ecu_status_flags self.ota_updater = mocker.MagicMock(spec=_OTAUpdater) self.boot_controller = mocker.MagicMock(spec=BootControllerProtocol) # patch boot_controller for otaclient initializing self.boot_controller.load_version.return_value = self.CURRENT_FIRMWARE_VERSION - self.boot_controller.get_booted_ota_status.return_value = OTAStatus.SUCCESS + self.boot_controller.get_booted_ota_status = mocker.MagicMock( + return_value=OTAStatus.SUCCESS + ) # patch inject mocked updater mocker.patch(f"{OTA_CORE_MODULE}._OTAUpdater", return_value=self.ota_updater) @@ -254,7 +259,7 @@ def mock_setup( # start otaclient self.ota_client = OTAClient( - control_flags=self.control_flags, + ecu_status_flags=ecu_status_flags, status_report_queue=status_report_queue, ) @@ -265,6 +270,7 @@ def test_update_normal_finished(self): version=self.UPDATE_FIRMWARE_VERSION, url_base=self.OTA_IMAGE_URL, cookies_json=self.UPDATE_COOKIES_JSON, + session_id="test_update_normal_finished", ) ) @@ -283,28 +289,10 @@ def test_update_interrupted(self): version=self.UPDATE_FIRMWARE_VERSION, url_base=self.OTA_IMAGE_URL, cookies_json=self.UPDATE_COOKIES_JSON, + session_id="test_updaste_interrupted", ) ) # --- assertion on interrupted OTA update --- # self.ota_updater.execute.assert_called_once() assert self.ota_client.live_ota_status == OTAStatus.FAILURE - - def test_status_in_update(self, mocker: pytest_mock.MockerFixture): - # --- mock setup --- # - _ota_updater_mocker = mocker.MagicMock(spec=_OTAUpdater) - mocker.patch(f"{OTA_CORE_MODULE}._OTAUpdater", _ota_updater_mocker) - self.ota_client._live_ota_status = OTAStatus.UPDATING - - # --- execution --- # - self.ota_client.update( - request=UpdateRequestV2( - version=self.UPDATE_FIRMWARE_VERSION, - url_base=self.OTA_IMAGE_URL, - cookies_json=self.UPDATE_COOKIES_JSON, - ) - ) - - # --- assertion --- # - # confirm that the OTA update doesn't happen - _ota_updater_mocker.assert_not_called() diff --git a/tests/test_otaclient/test_status_monitor.py b/tests/test_otaclient/test_status_monitor.py index d423d14da..411927e24 100644 --- a/tests/test_otaclient/test_status_monitor.py +++ b/tests/test_otaclient/test_status_monitor.py @@ -21,12 +21,10 @@ import random import time from queue import Queue -from typing import Generator import pytest from otaclient._status_monitor import ( - TERMINATE_SENTINEL, OTAClientStatusCollector, OTAStatusChangeReport, OTAUpdatePhaseChangeReport, @@ -51,24 +49,11 @@ class TestStatusMonitor: DOWNLOAD_NUM = DWONLOAD_SIZE = TOTAL_DOWNLOAD_SIZE = 600 MULTI_PATHS_FILE = MULTI_PATHS_FILE_SIZE = 100 - @pytest.fixture(autouse=True, scope="class") - def msg_queue(self) -> Generator[Queue[StatusReport], None, None]: - _queue = Queue() - yield _queue - - @pytest.fixture(autouse=True, scope="class") - def status_collector(self, msg_queue: Queue[StatusReport]): - status_collector = OTAClientStatusCollector(msg_queue=msg_queue) - _thread = status_collector.start() - try: - yield status_collector - finally: - msg_queue.put_nowait(TERMINATE_SENTINEL) - _thread.join() - def test_otaclient_start( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ): + status_collector, msg_queue = ota_status_collector + _test_failure_reason = "test_no_failure_reason" _test_current_version = "test_current_version" msg_queue.put_nowait( @@ -98,8 +83,10 @@ def test_otaclient_start( assert otaclient_status.failure_reason == _test_failure_reason def test_start_ota_update( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ): + status_collector, msg_queue = ota_status_collector + # ------ execution ------ # msg_queue.put_nowait( StatusReport( @@ -139,8 +126,9 @@ def test_start_ota_update( assert update_meta.update_firmware_version == self.UPDATE_VERSION_FOR_TEST def test_process_metadata( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector # ------ execution ------ # msg_queue.put_nowait( StatusReport( @@ -176,12 +164,16 @@ def test_process_metadata( assert update_progress.downloaded_bytes == self.METADATA_SIZE def test_filter_invalid_session_id( - self, msg_queue: Queue[StatusReport], caplog: pytest.LogCaptureFixture + self, + ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]], + caplog: pytest.LogCaptureFixture, ) -> None: """This test put reports with invalid session_id into the msg_queue. If the filter is working, all the later test methods will not fail. """ + _, msg_queue = ota_status_collector + _invalid_session_id = "invalid_session_id" # put an update meta change report @@ -224,8 +216,9 @@ def test_filter_invalid_session_id( assert all(_record.levelno == logging.WARNING for _record in caplog.records) def test_calculate_delta( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector _now = int(time.time()) # ------ execution ------ # @@ -282,8 +275,9 @@ def test_calculate_delta( assert update_meta.total_remove_files_num == 123 def test_download_ota_files( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector _now = int(time.time()) # ------ execution ------ # @@ -336,8 +330,10 @@ def test_download_ota_files( ) def test_apply_update( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector + _now = int(time.time()) # ------ execution ------ # @@ -374,8 +370,9 @@ def test_apply_update( ) and update_timing.update_apply_start_timestamp == _now def test_post_update( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector _now = int(time.time()) msg_queue.put_nowait( StatusReport( @@ -397,8 +394,10 @@ def test_post_update( ) and update_timing.post_update_start_timestamp == _now def test_finalizing_update( - self, status_collector: OTAClientStatusCollector, msg_queue: Queue[StatusReport] + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, msg_queue = ota_status_collector + _now = int(time.time()) msg_queue.put_nowait( StatusReport( @@ -417,8 +416,9 @@ def test_finalizing_update( assert otaclient_status.update_phase == UpdatePhase.FINALIZING_UPDATE def test_confirm_update_progress( - self, status_collector: OTAClientStatusCollector + self, ota_status_collector: tuple[OTAClientStatusCollector, Queue[StatusReport]] ) -> None: + status_collector, _ = ota_status_collector time.sleep(2) # wait for reports being processed otaclient_status = status_collector.otaclient_status diff --git a/tests/test_otaclient/test_utils.py b/tests/test_otaclient/test_utils.py index cde15e0b6..0ffed4e0f 100644 --- a/tests/test_otaclient/test_utils.py +++ b/tests/test_otaclient/test_utils.py @@ -20,7 +20,7 @@ import pytest -from otaclient.utils import wait_and_log +from otaclient._utils import wait_and_log logger = logging.getLogger(__name__) @@ -42,8 +42,9 @@ def test_wait_and_log(caplog: pytest.LogCaptureFixture): _msg = "ticking flag" wait_and_log( - _flag, + _flag.is_set, _msg, + check_for=True, check_interval=1, log_interval=2, log_func=logger.warning,