diff --git a/src/otaclient_common/shm_status.py b/src/otaclient_common/shm_status.py index f20ca7cd4..f3b5c04aa 100644 --- a/src/otaclient_common/shm_status.py +++ b/src/otaclient_common/shm_status.py @@ -50,9 +50,24 @@ class RWBusy(Exception): ... class SHA512Verifier: """Base class for specifying hash alg related configurations.""" - digest_alg = DEFAULT_HASH_ALG - digest_size = hashlib.new(digest_alg).digest_size - min_encap_msg_len = RWLOCK_LEN + digest_size + PAYLOAD_LEN_BYTES + DIGEST_ALG = DEFAULT_HASH_ALG + DIGEST_SIZE = hashlib.new(DIGEST_ALG).digest_size + MIN_ENCAP_MSG_LEN = RWLOCK_LEN + DIGEST_SIZE + PAYLOAD_LEN_BYTES + + _key: bytes + + def cal_hmac(self, _raw_msg: bytes) -> bytes: + return hmac.digest(key=self._key, msg=_raw_msg, digest=self.DIGEST_ALG) + + def verify_msg(self, _raw_msg: bytes, _expected_hmac: bytes) -> bool: + return hmac.compare_digest( + hmac.digest( + key=self._key, + msg=_raw_msg, + digest=self.DIGEST_ALG, + ), + _expected_hmac, + ) class MPSharedStatusReader(SHA512Verifier, Generic[T]): @@ -78,7 +93,7 @@ def __init__( raise ValueError(f"failed to connect share memory with {name=}") self.mem_size = size = shm.size - self.msg_max_size = size - self.min_encap_msg_len + self.msg_max_size = size - self.MIN_ENCAP_MSG_LEN self._key = key def atexit(self) -> None: @@ -103,8 +118,8 @@ def sync_msg(self) -> T: _cursor += RWLOCK_LEN # parsing the msg - input_hmac = bytes(buffer[_cursor : _cursor + self.digest_size]) - _cursor += self.digest_size + input_hmac = bytes(buffer[_cursor : _cursor + self.DIGEST_SIZE]) + _cursor += self.DIGEST_SIZE _payload_len_bytes = bytes(buffer[_cursor : _cursor + PAYLOAD_LEN_BYTES]) payload_len = int.from_bytes(_payload_len_bytes, "big", signed=False) @@ -114,9 +129,7 @@ def sync_msg(self) -> T: raise ValueError(f"invalid msg: {payload_len=} > {self.msg_max_size}") payload = bytes(buffer[_cursor : _cursor + payload_len]) - payload_hmac = hmac.digest(key=self._key, msg=payload, digest=self.digest_alg) - - if hmac.compare_digest(payload_hmac, input_hmac): + if self.verify_msg(payload, input_hmac): return pickle.loads(payload) raise ValueError("failed to validate input msg") @@ -133,18 +146,18 @@ def __init__( key: bytes, ) -> None: if create: - _msg_max_size = size - self.min_encap_msg_len + _msg_max_size = size - self.MIN_ENCAP_MSG_LEN if _msg_max_size < 0: - raise ValueError(f"{size=} < {self.min_encap_msg_len=}") + raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") self._shm = shm = mp_shm.SharedMemory(name=name, size=size, create=True) self.mem_size = shm.size else: self._shm = shm = mp_shm.SharedMemory(name=name, create=False) self.mem_size = size = shm.size - _msg_max_size = size - self.min_encap_msg_len + _msg_max_size = size - self.MIN_ENCAP_MSG_LEN if _msg_max_size < 0: shm.close() - raise ValueError(f"{size=} < {self.min_encap_msg_len=}") + raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") self.name = shm.name self._key = key @@ -166,11 +179,10 @@ def write_msg(self, obj: T) -> None: if _pickled_len > self.msg_max_size: raise ValueError(f"exceed {self.msg_max_size=}: {_pickled_len=}") - _hmac = hmac.digest(key=self._key, msg=_pickled, digest=self.digest_alg) msg = b"".join( [ RWLOCK_LOCKED, - _hmac, + self.cal_hmac(_pickled), _pickled_len.to_bytes(PAYLOAD_LEN_BYTES, "big", signed=False), _pickled, ]