Skip to content

Commit

Permalink
Merge branch 'main' into Bodong-Yang-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang authored Nov 27, 2024
2 parents 4687856 + afc2f67 commit 04521b7
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/gen_requirements_txt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
# For more details about this restriction, please refer to:
# https://github.com/peter-evans/create-pull-request/issues/48 and
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
ssh_key: ${{ secrets.DEPLOY_KEY }}
ssh-key: ${{ secrets.DEPLOY_KEY }}
persist-credentials: true

- name: setup python
Expand Down
1 change: 1 addition & 0 deletions docker/test_base/entry_point.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ echo "execute test with coverage"
cd ${TEST_ROOT}
hatch env create dev
hatch run dev:coverage run -m pytest --junit-xml=${OUTPUT_DIR}/pytest.xml ${@:-}
hatch run dev:coverage combine
hatch run dev:coverage xml -o ${OUTPUT_DIR}/coverage.xml
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ extend-exclude = '''(
)'''

[tool.coverage.run]
concurrency = [
"multiprocessing",
"thread",
]
branch = false
relative_files = true
parallel = true
source = [
"otaclient",
"otaclient_api",
Expand Down
207 changes: 207 additions & 0 deletions src/otaclient_common/shm_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright 2022 TIER IV, INC. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A lib for sharing status between processes.
shared memory layout:
rwlock(1byte) | hmac-sha512 of msg(64bytes) | msg_len(4bytes,big) | msg(<msg_len>bytes)
In which, msg is pickled python object.
"""


from __future__ import annotations

import hashlib
import hmac
import logging
import multiprocessing.shared_memory as mp_shm
import pickle
import time
from typing import Generic

from otaclient_common.typing import T

logger = logging.getLogger(__name__)

DEFAULT_HASH_ALG = "sha512"
DEFAULT_KEY_LEN = hashlib.new(DEFAULT_HASH_ALG).digest_size

RWLOCK_LEN = 1 # byte
PAYLOAD_LEN_BYTES = 4 # bytes

RWLOCK_LOCKED = b"\xab"
RWLOCK_OPEN = b"\x54"


class RWBusy(Exception): ...


class SHA512Verifier:
"""Base class for specifying hash alg related configurations."""

DIGEST_ALG = "sha512"
DIGEST_SIZE = hashlib.new(DIGEST_ALG).digest_size
MIN_ENCAP_MSG_LEN = RWLOCK_LEN + DIGEST_SIZE + PAYLOAD_LEN_BYTES

_key: bytes

def cal_hmac(self, _raw_msg: bytes) -> bytes:
return hmac.digest(key=self._key, msg=_raw_msg, digest=self.DIGEST_ALG)

def verify_msg(self, _raw_msg: bytes, _expected_hmac: bytes) -> bool:
return hmac.compare_digest(
hmac.digest(
key=self._key,
msg=_raw_msg,
digest=self.DIGEST_ALG,
),
_expected_hmac,
)


def _ensure_connect_shm(
name: str, *, max_retry: int, retry_interval: int
) -> mp_shm.SharedMemory:
for _idx in range(max_retry):
try:
return mp_shm.SharedMemory(name=name, create=False)
except Exception as e:
logger.warning(
f"retry #{_idx}: failed to connect to {name=}: {e!r}, keep retrying ..."
)
time.sleep(retry_interval)
raise ValueError(f"failed to connect share memory with {name=}")


class MPSharedStatusReader(SHA512Verifier, Generic[T]):

def __init__(
self,
*,
name: str,
key: bytes,
max_retry: int = 6,
retry_interval: int = 1,
) -> None:
self._shm = shm = _ensure_connect_shm(
name, max_retry=max_retry, retry_interval=retry_interval
)
self.mem_size = size = shm.size
self.msg_max_size = size - self.MIN_ENCAP_MSG_LEN
self._key = key

def atexit(self) -> None:
self._shm.close()

def sync_msg(self) -> T:
"""Get msg from shared memory.
Raises:
RWBusy if rwlock indicates the writer is writing or not yet ready.
ValueError for invalid msg.
"""
buffer = self._shm.buf

# check if we can read
_cursor = 0
rwlock = bytes(buffer[_cursor:RWLOCK_LEN])
if rwlock != RWLOCK_OPEN:
if rwlock == RWLOCK_LOCKED:
raise RWBusy("write in progress, abort")
raise RWBusy("no msg has been written yet")
_cursor += RWLOCK_LEN

# parsing the msg
input_hmac = bytes(buffer[_cursor : _cursor + self.DIGEST_SIZE])
_cursor += self.DIGEST_SIZE

_payload_len_bytes = bytes(buffer[_cursor : _cursor + PAYLOAD_LEN_BYTES])
payload_len = int.from_bytes(_payload_len_bytes, "big", signed=False)
_cursor += PAYLOAD_LEN_BYTES

if payload_len > self.msg_max_size:
raise ValueError(f"invalid msg: {payload_len=} > {self.msg_max_size}")

payload = bytes(buffer[_cursor : _cursor + payload_len])
if self.verify_msg(payload, input_hmac):
return pickle.loads(payload)
raise ValueError("failed to validate input msg")


class MPSharedStatusWriter(SHA512Verifier, Generic[T]):

def __init__(
self,
*,
name: str | None = None,
size: int = 0,
key: bytes,
create: bool = False,
msg_max_size: int | None = None,
max_retry: int = 6,
retry_interval: int = 1,
) -> None:
if create:
_msg_max_size = size - self.MIN_ENCAP_MSG_LEN
if _msg_max_size < 0:
raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}")
self._shm = shm = mp_shm.SharedMemory(name=name, size=size, create=True)
self.mem_size = shm.size

elif name:
self._shm = shm = _ensure_connect_shm(
name, max_retry=max_retry, retry_interval=retry_interval
)
self.mem_size = size = shm.size
_msg_max_size = size - self.MIN_ENCAP_MSG_LEN
if _msg_max_size < 0:
shm.close()
raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}")
else:
raise ValueError("<name> must be specified if <create> is False")

self.name = shm.name
self._key = key
self.msg_max_size = min(_msg_max_size, msg_max_size or float("infinity"))

def atexit(self) -> None:
self._shm.close()

def write_msg(self, obj: T) -> None:
"""Write msg to shared memory.
Raises:
ValueError on invalid msg or exceeding shared memory size.
"""
buffer = self._shm.buf
_pickled = pickle.dumps(obj)
_pickled_len = len(_pickled)

if _pickled_len > self.msg_max_size:
raise ValueError(f"exceed {self.msg_max_size=}: {_pickled_len=}")

msg = b"".join(
[
RWLOCK_LOCKED,
self.cal_hmac(_pickled),
_pickled_len.to_bytes(PAYLOAD_LEN_BYTES, "big", signed=False),
_pickled,
]
)
msg_len = len(msg)
if msg_len > self.mem_size:
raise ValueError(f"{msg_len=} > {self.mem_size=}")

buffer[:msg_len] = msg
buffer[:1] = RWLOCK_OPEN
Loading

0 comments on commit 04521b7

Please sign in to comment.