diff --git a/src/otaclient_common/shm_status.py b/src/otaclient_common/shm_status.py index 6cfbc50a7..ed2f5fd4b 100644 --- a/src/otaclient_common/shm_status.py +++ b/src/otaclient_common/shm_status.py @@ -70,6 +70,21 @@ def verify_msg(self, _raw_msg: bytes, _expected_hmac: bytes) -> bool: ) +def _ensure_connect_shm( + name: str, *, max_retry: int, retry_interval: int +) -> mp_shm.SharedMemory: + for _idx in range(max_retry): + try: + return mp_shm.SharedMemory(name=name, create=False) + except Exception as e: + logger.warning( + f"retry #{_idx}: failed to connect to {name=}: {e!r}, keep retrying ..." + ) + time.sleep(retry_interval) + else: + raise ValueError(f"failed to connect share memory with {name=}") + + class MPSharedStatusReader(SHA512Verifier, Generic[T]): def __init__( @@ -80,18 +95,9 @@ def __init__( max_retry: int = 6, retry_interval: int = 1, ) -> None: - for _idx in range(max_retry): - try: - self._shm = shm = mp_shm.SharedMemory(name=name, create=False) - break - except Exception as e: - logger.warning( - f"retry #{_idx}: failed to connect to {name=}: {e!r}, keep retrying ..." - ) - time.sleep(retry_interval) - else: - raise ValueError(f"failed to connect share memory with {name=}") - + self._shm = shm = _ensure_connect_shm( + name, max_retry=max_retry, retry_interval=retry_interval + ) self.mem_size = size = shm.size self.msg_max_size = size - self.MIN_ENCAP_MSG_LEN self._key = key @@ -141,9 +147,11 @@ def __init__( *, name: str | None = None, size: int = 0, + key: bytes, create: bool = False, msg_max_size: int | None = None, - key: bytes, + max_retry: int = 6, + retry_interval: int = 1, ) -> None: if create: _msg_max_size = size - self.MIN_ENCAP_MSG_LEN @@ -151,13 +159,18 @@ def __init__( raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") self._shm = shm = mp_shm.SharedMemory(name=name, size=size, create=True) self.mem_size = shm.size - else: - self._shm = shm = mp_shm.SharedMemory(name=name, create=False) + + elif name: + self._shm = shm = _ensure_connect_shm( + name, max_retry=max_retry, retry_interval=retry_interval + ) self.mem_size = size = shm.size _msg_max_size = size - self.MIN_ENCAP_MSG_LEN if _msg_max_size < 0: shm.close() raise ValueError(f"{size=} < {self.MIN_ENCAP_MSG_LEN=}") + else: + raise ValueError(" must be specified if is False") self.name = shm.name self._key = key