From 0509fbd92cd1f0b00d949f0dd64cdf8c0e02a1ba Mon Sep 17 00:00:00 2001 From: Konstantinos Panayiotou Date: Thu, 19 Dec 2024 14:59:24 +0200 Subject: [PATCH] Major transport fixes for redis and mqtt --- commlib/endpoints.py | 4 ++-- commlib/pubsub.py | 1 + commlib/transports/mqtt.py | 25 +++++++++++-------------- commlib/transports/redis.py | 4 ++-- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/commlib/endpoints.py b/commlib/endpoints.py index c82780b..49285ab 100644 --- a/commlib/endpoints.py +++ b/commlib/endpoints.py @@ -89,7 +89,7 @@ def run(self): self._transport.start() self._state = EndpointState.CONNECTED else: - self.logger().debug( + self.logger().error( f"Transport already connected - cannot run {self.__class__.__name__}") def stop(self) -> None: @@ -110,7 +110,7 @@ def stop(self) -> None: self._transport.stop() self._state = EndpointState.DISCONNECTED else: - self.logger().debug( + self.logger().error( f"Transport is not connected - cannot stop {self.__class__.__name__}") def __del__(self): diff --git a/commlib/pubsub.py b/commlib/pubsub.py index 73692f3..ce00e30 100644 --- a/commlib/pubsub.py +++ b/commlib/pubsub.py @@ -127,6 +127,7 @@ def run(self) -> None: if not self._transport.is_connected and \ self._state not in (EndpointState.CONNECTED, EndpointState.CONNECTING): + self._transport.start() self._main_thread = threading.Thread(target=self.run_forever) self._main_thread.daemon = True self._t_stop_event = threading.Event() diff --git a/commlib/transports/mqtt.py b/commlib/transports/mqtt.py index 8688290..afbb471 100644 --- a/commlib/transports/mqtt.py +++ b/commlib/transports/mqtt.py @@ -340,7 +340,9 @@ def start(self) -> None: Start the event loop. Cannot create any more endpoints from here on. """ - self._client.loop_start() + if not self.is_connected: + self.connect() + self._client.loop_start() def stop(self) -> None: """stop. @@ -378,7 +380,6 @@ def __init__(self, *args, **kwargs): serializer=self._serializer, compression=self._compression, ) - self._transport.connect() def publish(self, msg: PubSubMessage) -> None: """publish. @@ -445,11 +446,10 @@ def __init__(self, *args, **kwargs): serializer=self._serializer, compression=self._compression, ) - self._transport.connect() def run(self): - self._topic = self._transport.subscribe(self._topic, self._on_message) super().run() + self._topic = self._transport.subscribe(self._topic, self._on_message) self.log.debug(f"Started Subscriber: <{self._topic}>") def run_forever(self): @@ -526,7 +526,6 @@ def __init__(self, *args, **kwargs): serializer=self._serializer, compression=self._compression, ) - self._transport.connect() def _send_response(self, data: Dict[str, Any], reply_to: str): self._comm_obj.header.timestamp = gen_timestamp() # pylint: disable=E0237 @@ -575,7 +574,6 @@ def run_forever(self): self._transport.subscribe( self._rpc_name, self._on_request_handle, qos=MQTTQoS.L1 ) - self._transport.start() while True: if self._t_stop_event is not None: if self._t_stop_event.is_set(): @@ -599,11 +597,6 @@ def __init__(self, *args, **kwargs): serializer=self._serializer, compression=self._compression, ) - self._transport.connect() - for uri in self._svc_map: - callback = self._svc_map[uri][0] - msg_type = self._svc_map[uri][1] - self._register_endpoint(uri, callback, msg_type) def _send_response(self, data: Dict[str, Any], reply_to: str): """_send_response. @@ -690,9 +683,14 @@ def _register_endpoint( self.log.info(f"Registering endpoint <{full_uri}>") self._transport.subscribe(full_uri, self._on_request_handle, qos=MQTTQoS.L1) + def _register_endpoints(self): + for uri in self._svc_map: + callback = self._svc_map[uri][0] + msg_type = self._svc_map[uri][1] + self._register_endpoint(uri, callback, msg_type) + def run_forever(self): - """run_forever.""" - self._transport.start() + self._register_endpoints() while True: if self._t_stop_event is not None: if self._t_stop_event.is_set(): @@ -722,7 +720,6 @@ def __init__(self, *args, **kwargs): serializer=self._serializer, compression=self._compression, ) - self._transport.connect() def _gen_queue_name(self): """_gen_queue_name.""" diff --git a/commlib/transports/redis.py b/commlib/transports/redis.py index 63e3e8d..20b10f7 100644 --- a/commlib/transports/redis.py +++ b/commlib/transports/redis.py @@ -84,6 +84,7 @@ def __init__( self._serializer = serializer self._compression = compression self._connected = False + self._rsub = None @property def is_connected(self) -> bool: @@ -225,7 +226,6 @@ def run_forever(self): self._transport.delete_queue(self._rpc_name) while True: msgq, payload = self._transport.wait_for_msg(self._rpc_name, timeout=0) - self._detach_request_handler(payload) if self._t_stop_event is not None: if self._t_stop_event.is_set(): @@ -396,7 +396,6 @@ def __init__(self, queue_size: Optional[int] = 1, *args, **kwargs): self._subscriber_thread = None self._queue_size = queue_size super(Subscriber, self).__init__(*args, **kwargs) - self._transport = RedisTransport( conn_params=self._conn_params, serializer=self._serializer, @@ -404,6 +403,7 @@ def __init__(self, queue_size: Optional[int] = 1, *args, **kwargs): ) def run(self): + super().run() self._subscriber_thread = self._transport.subscribe( self._topic, self._on_message )