diff --git a/.github/workflows/gen_requirements_txt.yaml b/.github/workflows/gen_requirements_txt.yaml index 57f27a210..f5afcef8f 100644 --- a/.github/workflows/gen_requirements_txt.yaml +++ b/.github/workflows/gen_requirements_txt.yaml @@ -24,7 +24,7 @@ jobs: # For more details about this restriction, please refer to: # https://github.com/peter-evans/create-pull-request/issues/48 and # https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs - ssh_key: ${{ secrets.DEPLOY_KEY }} + ssh-key: ${{ secrets.DEPLOY_KEY }} persist-credentials: true - name: setup python diff --git a/docker/test_base/entry_point.sh b/docker/test_base/entry_point.sh index 4416da4cc..4ab2779e1 100644 --- a/docker/test_base/entry_point.sh +++ b/docker/test_base/entry_point.sh @@ -13,4 +13,5 @@ echo "execute test with coverage" cd ${TEST_ROOT} hatch env create dev hatch run dev:coverage run -m pytest --junit-xml=${OUTPUT_DIR}/pytest.xml ${@:-} +hatch run dev:coverage combine hatch run dev:coverage xml -o ${OUTPUT_DIR}/coverage.xml diff --git a/pyproject.toml b/pyproject.toml index 966d0884f..c98830edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,8 +107,13 @@ extend-exclude = '''( )''' [tool.coverage.run] +concurrency = [ + "multiprocessing", + "thread", +] branch = false relative_files = true +parallel = true source = [ "otaclient", "otaclient_api", diff --git a/src/otaclient_common/shm_status.py b/src/otaclient_common/shm_status.py new file mode 100644 index 000000000..4dc7e025a --- /dev/null +++ b/src/otaclient_common/shm_status.py @@ -0,0 +1,207 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A lib for sharing status between processes. + +shared memory layout: + +rwlock(1byte) | hmac-sha512 of msg(64bytes) | msg_len(4bytes,big) | msg(bytes) +In which, msg is pickled python object. +""" + + +from __future__ import annotations + +import hashlib +import hmac +import logging +import multiprocessing.shared_memory as mp_shm +import pickle +import time +from typing import Generic + +from otaclient_common.typing import T + +logger = logging.getLogger(__name__) + +DEFAULT_HASH_ALG = "sha512" +DEFAULT_KEY_LEN = hashlib.new(DEFAULT_HASH_ALG).digest_size + +RWLOCK_LEN = 1 # byte +PAYLOAD_LEN_BYTES = 4 # bytes + +RWLOCK_LOCKED = b"\xab" +RWLOCK_OPEN = b"\x54" + + +class RWBusy(Exception): ... + + +class SHA512Verifier: + """Base class for specifying hash alg related configurations.""" + + DIGEST_ALG = "sha512" + DIGEST_SIZE = hashlib.new(DIGEST_ALG).digest_size + MIN_ENCAP_MSG_LEN = RWLOCK_LEN + DIGEST_SIZE + PAYLOAD_LEN_BYTES + + _key: bytes + + def cal_hmac(self, _raw_msg: bytes) -> bytes: + return hmac.digest(key=self._key, msg=_raw_msg, digest=self.DIGEST_ALG) + + def verify_msg(self, _raw_msg: bytes, _expected_hmac: bytes) -> bool: + return hmac.compare_digest( + hmac.digest( + key=self._key, + msg=_raw_msg, + digest=self.DIGEST_ALG, + ), + _expected_hmac, + ) + + +def _ensure_connect_shm( + name: str, *, max_retry: int, retry_interval: int +) -> mp_shm.SharedMemory: + for _idx in range(max_retry): + try: + return mp_shm.SharedMemory(name=name, create=False) + except Exception as e: + logger.warning( + f"retry #{_idx}: failed to connect to {name=}: {e!r}, keep retrying ..." + ) + time.sleep(retry_interval) + raise ValueError(f"failed to connect share memory with {name=}") + + +class MPSharedStatusReader(SHA512Verifier, Generic[T]): + + def __init__( + self, + *, + name: str, + key: bytes, + max_retry: int = 6, + retry_interval: int = 1, + ) -> None: + self._shm = shm = _ensure_connect_shm( + name, max_retry=max_retry, retry_interval=retry_interval + ) + self.mem_size = size = shm.size + self.msg_max_size = size - self.MIN_ENCAP_MSG_LEN + self._key = key + + def atexit(self) -> None: + self._shm.close() + + def sync_msg(self) -> T: + """Get msg from shared memory. + + Raises: + RWBusy if rwlock indicates the writer is writing or not yet ready. + ValueError for invalid msg. + """ + buffer = self._shm.buf + + # check if we can read + _cursor = 0 + rwlock = bytes(buffer[_cursor:RWLOCK_LEN]) + if rwlock != RWLOCK_OPEN: + if rwlock == RWLOCK_LOCKED: + raise RWBusy("write in progress, abort") + raise RWBusy("no msg has been written yet") + _cursor += RWLOCK_LEN + + # parsing the msg + input_hmac = bytes(buffer[_cursor : _cursor + self.DIGEST_SIZE]) + _cursor += self.DIGEST_SIZE + + _payload_len_bytes = bytes(buffer[_cursor : _cursor + PAYLOAD_LEN_BYTES]) + payload_len = int.from_bytes(_payload_len_bytes, "big", signed=False) + _cursor += PAYLOAD_LEN_BYTES + + if payload_len > self.msg_max_size: + raise ValueError(f"invalid msg: {payload_len=} > {self.msg_max_size}") + + payload = bytes(buffer[_cursor : _cursor + payload_len]) + if self.verify_msg(payload, input_hmac): + return pickle.loads(payload) + raise ValueError("failed to validate input msg") + + +class MPSharedStatusWriter(SHA512Verifier, Generic[T]): + + def __init__( + self, + *, + name: str | None = None, + size: int = 0, + key: bytes, + create: bool = False, + msg_max_size: int | None = None, + max_retry: int = 6, + retry_interval: int = 1, + ) -> None: + if create: + _msg_max_size = size - self.MIN_ENCAP_MSG_LEN + if _msg_max_size < 0: + raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") + self._shm = shm = mp_shm.SharedMemory(name=name, size=size, create=True) + self.mem_size = shm.size + + elif name: + self._shm = shm = _ensure_connect_shm( + name, max_retry=max_retry, retry_interval=retry_interval + ) + self.mem_size = size = shm.size + _msg_max_size = size - self.MIN_ENCAP_MSG_LEN + if _msg_max_size < 0: + shm.close() + raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") + else: + raise ValueError(" must be specified if is False") + + self.name = shm.name + self._key = key + self.msg_max_size = min(_msg_max_size, msg_max_size or float("infinity")) + + def atexit(self) -> None: + self._shm.close() + + def write_msg(self, obj: T) -> None: + """Write msg to shared memory. + + Raises: + ValueError on invalid msg or exceeding shared memory size. + """ + buffer = self._shm.buf + _pickled = pickle.dumps(obj) + _pickled_len = len(_pickled) + + if _pickled_len > self.msg_max_size: + raise ValueError(f"exceed {self.msg_max_size=}: {_pickled_len=}") + + msg = b"".join( + [ + RWLOCK_LOCKED, + self.cal_hmac(_pickled), + _pickled_len.to_bytes(PAYLOAD_LEN_BYTES, "big", signed=False), + _pickled, + ] + ) + msg_len = len(msg) + if msg_len > self.mem_size: + raise ValueError(f"{msg_len=} > {self.mem_size=}") + + buffer[:msg_len] = msg + buffer[:1] = RWLOCK_OPEN diff --git a/tests/test_otaclient_common/test_shm_status.py b/tests/test_otaclient_common/test_shm_status.py new file mode 100644 index 000000000..2d1631d31 --- /dev/null +++ b/tests/test_otaclient_common/test_shm_status.py @@ -0,0 +1,186 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import multiprocessing as mp +import multiprocessing.shared_memory as mp_shm +import multiprocessing.synchronize as mp_sync +import secrets +import time +from dataclasses import dataclass +from functools import partial + +import pytest + +from otaclient_common.shm_status import ( + DEFAULT_KEY_LEN, + MPSharedStatusReader, + MPSharedStatusWriter, + RWBusy, +) + + +@dataclass +class OuterMsg: + _inner_msg: InnerMsg + + +@dataclass +class InnerMsg: + i_int: int + i_str: str + + +class MsgReader(MPSharedStatusReader[OuterMsg]): ... + + +class MsgWriter(MPSharedStatusWriter[OuterMsg]): ... + + +DATA_ENTRIES_NUM = 20 +_TEST_DATA = { + _idx: OuterMsg( + InnerMsg( + i_int=_idx, + i_str=str(_idx), + ) + ) + for _idx in range(DATA_ENTRIES_NUM) +} +SHM_SIZE = 1024 + + +def writer_process( + shm_name: str, + key: bytes, + *, + interval: float, + write_all_flag: mp_sync.Event, +): + _shm_writer = MsgWriter(name=shm_name, key=key) + + for _, _entry in _TEST_DATA.items(): + _shm_writer.write_msg(_entry) + time.sleep(interval) + write_all_flag.set() + + +def read_slow_process( + shm_name: str, + key: bytes, + *, + interval: float, + success_flag: mp_sync.Event, +): + """Reader is slower than writer, we only need to ensure reader can read the latest written msg.""" + _shm_reader = MsgReader(name=shm_name, key=key) + + while True: + time.sleep(interval) + try: + _msg = _shm_reader.sync_msg() + except RWBusy: + continue + + if _msg._inner_msg.i_int == DATA_ENTRIES_NUM - 1: + return success_flag.set() + + +def read_fast_process( + shm_name: str, + key: bytes, + *, + interval: float, + success_flag: mp_sync.Event, +): + """Reader is faster than writer, we need to ensure all the msgs are read.""" + _shm_reader = MsgReader(name=shm_name, key=key) + _read = [False for _ in range(DATA_ENTRIES_NUM)] + + while True: + time.sleep(interval) + try: + _msg = _shm_reader.sync_msg() + except RWBusy: + continue + + _read[_msg._inner_msg.i_int] = True + if all(_read): + return success_flag.set() + + +WRITE_INTERVAL = 0.1 +READ_SLOW_INTERVAL = 0.5 +READ_FAST_INTERVAL = 0.01 + + +@pytest.mark.parametrize( + "reader_func, read_interval, timeout", + ( + ( + read_fast_process, + READ_FAST_INTERVAL, + WRITE_INTERVAL * DATA_ENTRIES_NUM + 3, + ), + ( + read_slow_process, + READ_SLOW_INTERVAL, + WRITE_INTERVAL * DATA_ENTRIES_NUM + 3, + ), + ), +) +def test_shm_status_read_fast(reader_func, read_interval, timeout): + _shm = mp_shm.SharedMemory(size=SHM_SIZE, create=True) + _mp_ctx = mp.get_context("spawn") + _key = secrets.token_bytes(DEFAULT_KEY_LEN) + + _write_all_flag = _mp_ctx.Event() + _success_flag = _mp_ctx.Event() + + _writer_p = _mp_ctx.Process( + target=partial( + writer_process, + shm_name=_shm.name, + key=_key, + interval=WRITE_INTERVAL, + write_all_flag=_write_all_flag, + ) + ) + _reader_p = _mp_ctx.Process( + target=partial( + reader_func, + shm_name=_shm.name, + key=_key, + interval=read_interval, + success_flag=_success_flag, + ) + ) + _writer_p.start() + _reader_p.start() + + time.sleep(timeout) + try: + assert _write_all_flag.is_set(), "writer timeout finish up writing" + assert _success_flag.is_set(), "reader failed to read all msg" + finally: + _writer_p.terminate() + _writer_p.join() + + _reader_p.terminate() + _reader_p.join() + + _shm.close() + _shm.unlink()