Skip to content

Commit

Permalink
shm_status: move hmac related implementation to SHA512Verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang committed Nov 27, 2024
1 parent 1c26ce7 commit c3a9819
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions src/otaclient_common/shm_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,24 @@ class RWBusy(Exception): ...
class SHA512Verifier:
"""Base class for specifying hash alg related configurations."""

digest_alg = DEFAULT_HASH_ALG
digest_size = hashlib.new(digest_alg).digest_size
min_encap_msg_len = RWLOCK_LEN + digest_size + PAYLOAD_LEN_BYTES
DIGEST_ALG = DEFAULT_HASH_ALG
DIGEST_SIZE = hashlib.new(DIGEST_ALG).digest_size
MIN_ENCAP_MSG_LEN = RWLOCK_LEN + DIGEST_SIZE + PAYLOAD_LEN_BYTES

_key: bytes

def cal_hmac(self, _raw_msg: bytes) -> bytes:
return hmac.digest(key=self._key, msg=_raw_msg, digest=self.DIGEST_ALG)

def verify_msg(self, _raw_msg: bytes, _expected_hmac: bytes) -> bool:
return hmac.compare_digest(
hmac.digest(
key=self._key,
msg=_raw_msg,
digest=self.DIGEST_ALG,
),
_expected_hmac,
)


class MPSharedStatusReader(SHA512Verifier, Generic[T]):
Expand All @@ -78,7 +93,7 @@ def __init__(
raise ValueError(f"failed to connect share memory with {name=}")

self.mem_size = size = shm.size
self.msg_max_size = size - self.min_encap_msg_len
self.msg_max_size = size - self.MIN_ENCAP_MSG_LEN
self._key = key

def atexit(self) -> None:
Expand All @@ -103,8 +118,8 @@ def sync_msg(self) -> T:
_cursor += RWLOCK_LEN

# parsing the msg
input_hmac = bytes(buffer[_cursor : _cursor + self.digest_size])
_cursor += self.digest_size
input_hmac = bytes(buffer[_cursor : _cursor + self.DIGEST_SIZE])
_cursor += self.DIGEST_SIZE

_payload_len_bytes = bytes(buffer[_cursor : _cursor + PAYLOAD_LEN_BYTES])
payload_len = int.from_bytes(_payload_len_bytes, "big", signed=False)
Expand All @@ -114,9 +129,7 @@ def sync_msg(self) -> T:
raise ValueError(f"invalid msg: {payload_len=} > {self.msg_max_size}")

payload = bytes(buffer[_cursor : _cursor + payload_len])
payload_hmac = hmac.digest(key=self._key, msg=payload, digest=self.digest_alg)

if hmac.compare_digest(payload_hmac, input_hmac):
if self.verify_msg(payload, input_hmac):
return pickle.loads(payload)
raise ValueError("failed to validate input msg")

Expand All @@ -133,18 +146,18 @@ def __init__(
key: bytes,
) -> None:
if create:
_msg_max_size = size - self.min_encap_msg_len
_msg_max_size = size - self.MIN_ENCAP_MSG_LEN
if _msg_max_size < 0:
raise ValueError(f"{size=} < {self.min_encap_msg_len=}")
raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}")
self._shm = shm = mp_shm.SharedMemory(name=name, size=size, create=True)
self.mem_size = shm.size
else:
self._shm = shm = mp_shm.SharedMemory(name=name, create=False)
self.mem_size = size = shm.size
_msg_max_size = size - self.min_encap_msg_len
_msg_max_size = size - self.MIN_ENCAP_MSG_LEN
if _msg_max_size < 0:
shm.close()
raise ValueError(f"{size=} < {self.min_encap_msg_len=}")
raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}")

self.name = shm.name
self._key = key
Expand All @@ -166,11 +179,10 @@ def write_msg(self, obj: T) -> None:
if _pickled_len > self.msg_max_size:
raise ValueError(f"exceed {self.msg_max_size=}: {_pickled_len=}")

_hmac = hmac.digest(key=self._key, msg=_pickled, digest=self.digest_alg)
msg = b"".join(
[
RWLOCK_LOCKED,
_hmac,
self.cal_hmac(_pickled),
_pickled_len.to_bytes(PAYLOAD_LEN_BYTES, "big", signed=False),
_pickled,
]
Expand Down

0 comments on commit c3a9819

Please sign in to comment.