diff --git a/asyncio_mqtt/__init__.py b/asyncio_mqtt/__init__.py index 7bc8547..c90470d 100644 --- a/asyncio_mqtt/__init__.py +++ b/asyncio_mqtt/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: BSD-3-Clause -from .client import Client, ProtocolVersion, TLSParameters, Will +from .client import Client, ProtocolVersion, ProxySettings, TLSParameters, Will from .error import MqttCodeError, MqttError from .version import __version__ @@ -8,6 +8,7 @@ "MqttCodeError", "Client", "Will", + "ProxySettings", "ProtocolVersion", "TLSParameters", "__version__", diff --git a/asyncio_mqtt/client.py b/asyncio_mqtt/client.py index 63cde94..5f96301 100644 --- a/asyncio_mqtt/client.py +++ b/asyncio_mqtt/client.py @@ -21,6 +21,7 @@ Generator, Iterable, Iterator, + List, Tuple, Union, cast, @@ -45,8 +46,8 @@ _PahoSocket = Union[socket.socket, ssl.SSLSocket, mqtt.WebsocketWrapper, Any] WebSocketHeaders = Union[ - Dict[str, Any], - Callable[[Dict[str, Any]], Dict[str, Any]], + Dict[str, str], + Callable[[Dict[str, str]], Dict[str, str]], ] @@ -107,6 +108,13 @@ def __init__( Tuple[int, int, None, int], ] +SubscribeTopic = Union[ + str, + Tuple[str, mqtt.SubscribeOptions], + List[Tuple[str, mqtt.SubscribeOptions]], + List[Tuple[str, int]], +] + P = ParamSpec("P") # TODO: Simplify the logic that surrounds `self._outgoing_calls_sem` with @@ -238,7 +246,7 @@ def __init__( self._client.message_retry_set(message_retry_set) if socket_options is None: socket_options = () - self._socket_options: tuple[SocketOption, ...] = tuple(socket_options) + self._socket_options = tuple(socket_options) @property def id(self) -> str: @@ -300,12 +308,7 @@ async def force_disconnect(self) -> None: @_outgoing_call async def subscribe( self, - topic: ( - str - | tuple[str, mqtt.SubscribeOptions] - | list[tuple[str, mqtt.SubscribeOptions]] - | list[tuple[str, int]] - ), + topic: SubscribeTopic, qos: int = 0, options: mqtt.SubscribeOptions | None = None, properties: Properties | None = None, diff --git a/pyproject.toml b/pyproject.toml index 0c5fee9..f06948a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,3 +5,12 @@ profile = "black" strict = true show_error_codes = true no_strict_concatenate = true # TODO: remove when dropping python 3.7 + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore:ssl.PROTOCOL_TLS is deprecated:DeprecationWarning", +] + +[tool.coverage.run] +branch = true diff --git a/tests/conftest.py b/tests/conftest.py index 2be0893..d7361cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,13 @@ +from __future__ import annotations + import sys -from typing import Any, Dict, Tuple +from typing import Any import pytest @pytest.fixture -def os_and_version() -> str: - return sys.platform + "_" + ".".join(map(str, sys.version_info[:2])) - - -@pytest.fixture -def anyio_backend() -> Tuple[str, Dict[str, Any]]: +def anyio_backend() -> tuple[str, dict[str, Any]]: if sys.platform == "win32": from asyncio.windows_events import WindowsSelectorEventLoopPolicy diff --git a/tests/mosquitto.org.crt b/tests/mosquitto.org.crt new file mode 100644 index 0000000..e76dbd8 --- /dev/null +++ b/tests/mosquitto.org.crt @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIIEAzCCAuugAwIBAgIUBY1hlCGvdj4NhBXkZ/uLUZNILAwwDQYJKoZIhvcNAQEL +BQAwgZAxCzAJBgNVBAYTAkdCMRcwFQYDVQQIDA5Vbml0ZWQgS2luZ2RvbTEOMAwG +A1UEBwwFRGVyYnkxEjAQBgNVBAoMCU1vc3F1aXR0bzELMAkGA1UECwwCQ0ExFjAU +BgNVBAMMDW1vc3F1aXR0by5vcmcxHzAdBgkqhkiG9w0BCQEWEHJvZ2VyQGF0Y2hv +by5vcmcwHhcNMjAwNjA5MTEwNjM5WhcNMzAwNjA3MTEwNjM5WjCBkDELMAkGA1UE +BhMCR0IxFzAVBgNVBAgMDlVuaXRlZCBLaW5nZG9tMQ4wDAYDVQQHDAVEZXJieTES +MBAGA1UECgwJTW9zcXVpdHRvMQswCQYDVQQLDAJDQTEWMBQGA1UEAwwNbW9zcXVp +dHRvLm9yZzEfMB0GCSqGSIb3DQEJARYQcm9nZXJAYXRjaG9vLm9yZzCCASIwDQYJ +KoZIhvcNAQEBBQADggEPADCCAQoCggEBAME0HKmIzfTOwkKLT3THHe+ObdizamPg +UZmD64Tf3zJdNeYGYn4CEXbyP6fy3tWc8S2boW6dzrH8SdFf9uo320GJA9B7U1FW +Te3xda/Lm3JFfaHjkWw7jBwcauQZjpGINHapHRlpiCZsquAthOgxW9SgDgYlGzEA +s06pkEFiMw+qDfLo/sxFKB6vQlFekMeCymjLCbNwPJyqyhFmPWwio/PDMruBTzPH +3cioBnrJWKXc3OjXdLGFJOfj7pP0j/dr2LH72eSvv3PQQFl90CZPFhrCUcRHSSxo +E6yjGOdnz7f6PveLIB574kQORwt8ePn0yidrTC1ictikED3nHYhMUOUCAwEAAaNT +MFEwHQYDVR0OBBYEFPVV6xBUFPiGKDyo5V3+Hbh4N9YSMB8GA1UdIwQYMBaAFPVV +6xBUFPiGKDyo5V3+Hbh4N9YSMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEL +BQADggEBAGa9kS21N70ThM6/Hj9D7mbVxKLBjVWe2TPsGfbl3rEDfZ+OKRZ2j6AC +6r7jb4TZO3dzF2p6dgbrlU71Y/4K0TdzIjRj3cQ3KSm41JvUQ0hZ/c04iGDg/xWf ++pp58nfPAYwuerruPNWmlStWAXf0UTqRtg4hQDWBuUFDJTuWuuBvEXudz74eh/wK +sMwfu1HFvjy5Z0iMDU8PUDepjVolOCue9ashlS4EB5IECdSR2TItnAIiIwimx839 +LdUdRudafMu5T5Xma182OC0/u/xRlEm+tvKGGmfFcN0piqVl8OrSPBgIlb+1IKJE +m/XriWr/Cq4h/JfB7NTsezVslgkBaoU= +-----END CERTIFICATE----- diff --git a/tests/test_client.py b/tests/test_client.py index 3f1f0f3..8356045 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,15 +1,27 @@ +from __future__ import annotations + +import logging +import ssl +import sys +from pathlib import Path + import anyio import anyio.abc +import paho.mqtt.client as mqtt import pytest -from asyncio_mqtt import Client -from asyncio_mqtt.client import ProtocolVersion, Will +from asyncio_mqtt import Client, ProtocolVersion, TLSParameters, Will +from asyncio_mqtt.types import PayloadType pytestmark = pytest.mark.anyio +HOSTNAME = "test.mosquitto.org" +OS_PY_VERSION = sys.platform + "_" + ".".join(map(str, sys.version_info[:2])) +TOPIC_HEADER = OS_PY_VERSION + "/tests/asyncio_mqtt/" -async def test_client_filtered_messages(os_and_version: str) -> None: - topic_header = os_and_version + "/tests/asyncio_mqtt/filtered_messages/" + +async def test_client_filtered_messages() -> None: + topic_header = TOPIC_HEADER + "filtered_messages/" good_topic = topic_header + "good" bad_topic = topic_header + "bad" @@ -19,7 +31,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: assert message.topic == good_topic tg.cancel_scope.cancel() - async with Client("test.mosquitto.org") as client: + async with Client(HOSTNAME) as client: async with anyio.create_task_group() as tg: await client.subscribe(topic_header + "#") tg.start_soon(handle_messages, tg) @@ -27,8 +39,8 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: await client.publish(good_topic, 2) -async def test_client_unfiltered_messages(os_and_version: str) -> None: - topic_header = os_and_version + "/tests/asyncio_mqtt/unfiltered_messages/" +async def test_client_unfiltered_messages() -> None: + topic_header = TOPIC_HEADER + "unfiltered_messages/" topic_filtered = topic_header + "filtered" topic_unfiltered = topic_header + "unfiltered" @@ -43,7 +55,7 @@ async def handle_filtered_messages() -> None: async for message in messages: assert message.topic == topic_filtered - async with Client("test.mosquitto.org") as client: + async with Client(HOSTNAME) as client: async with anyio.create_task_group() as tg: await client.subscribe(topic_header + "#") tg.start_soon(handle_filtered_messages) @@ -52,30 +64,27 @@ async def handle_filtered_messages() -> None: await client.publish(topic_unfiltered, 2) -async def test_client_unsubscribe(os_and_version: str) -> None: - topic_header = os_and_version + "/tests/asyncio_mqtt/unsubscribe/" +async def test_client_unsubscribe() -> None: + topic_header = TOPIC_HEADER + "unsubscribe/" topic1 = topic_header + "1" topic2 = topic_header + "2" - event = anyio.Event() async def handle_messages(tg: anyio.abc.TaskGroup) -> None: async with client.unfiltered_messages() as messages: - event.set() - i = 0 + is_first_message = True async for message in messages: - if i == 0: + if is_first_message: assert message.topic == topic1 - elif i == 1: + is_first_message = False + else: assert message.topic == topic2 tg.cancel_scope.cancel() - i += 1 - async with Client("test.mosquitto.org") as client: + async with Client(HOSTNAME) as client: async with anyio.create_task_group() as tg: await client.subscribe(topic1) await client.subscribe(topic2) tg.start_soon(handle_messages, tg) - await client.publish(topic1, 2) await client.unsubscribe(topic1) await client.publish(topic1, 2) @@ -87,17 +96,17 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: ((ProtocolVersion.V31, 22), (ProtocolVersion.V311, 0), (ProtocolVersion.V5, 0)), ) async def test_client_id(protocol: ProtocolVersion, length: int) -> None: - client = Client("test.mosquitto.org", protocol=protocol) + client = Client(HOSTNAME, protocol=protocol) assert len(client.id) == length async def test_client_will() -> None: - topic = "tests/asyncio_mqtt/will" + topic = TOPIC_HEADER + "will" event = anyio.Event() async def launch_client() -> None: with anyio.CancelScope(shield=True) as cs: - async with Client("test.mosquitto.org") as client: + async with Client(HOSTNAME) as client: await client.subscribe(topic) event.set() async with client.filtered_messages(topic) as messages: @@ -108,5 +117,174 @@ async def launch_client() -> None: async with anyio.create_task_group() as tg: tg.start_soon(launch_client) await event.wait() - async with Client("test.mosquitto.org", will=Will(topic)) as client: + async with Client(HOSTNAME, will=Will(topic)) as client: client._client._sock_close() # type: ignore[attr-defined] + + +async def test_client_tls_context() -> None: + topic = TOPIC_HEADER + "tls_context" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.filtered_messages(topic) as messages: + async for message in messages: + assert message.topic == topic + tg.cancel_scope.cancel() + + async with Client( + HOSTNAME, + 8883, + tls_context=ssl.SSLContext(protocol=ssl.PROTOCOL_TLS), + ) as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic) + tg.start_soon(handle_messages, tg) + await client.publish(topic) + + +async def test_client_tls_params() -> None: + topic = TOPIC_HEADER + "tls_params" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.filtered_messages(topic) as messages: + async for message in messages: + assert message.topic == topic + tg.cancel_scope.cancel() + + async with Client( + HOSTNAME, + 8883, + tls_params=TLSParameters( + ca_certs=str(Path.cwd() / "tests" / "mosquitto.org.crt") + ), + ) as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic) + tg.start_soon(handle_messages, tg) + await client.publish(topic) + + +async def test_client_username_password() -> None: + topic = TOPIC_HEADER + "username_password" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.filtered_messages(topic) as messages: + async for message in messages: + assert message.topic == topic + tg.cancel_scope.cancel() + + async with Client(HOSTNAME, username="asyncio-mqtt", password="012") as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic) + tg.start_soon(handle_messages, tg) + await client.publish(topic) + + +async def test_client_logger() -> None: + logger = logging.getLogger("asyncio-mqtt") + async with Client(HOSTNAME, logger=logger) as client: + assert logger is client._client._logger # type: ignore[attr-defined] + + +async def test_client_max_concurrent_outgoing_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + topic = TOPIC_HEADER + "max_concurrent_outgoing_calls" + + class MockPahoClient(mqtt.Client): + def subscribe( + self, + topic: str + | tuple[str, mqtt.SubscribeOptions] + | list[tuple[str, mqtt.SubscribeOptions]] + | list[tuple[str, int]], + qos: int = 0, + options: mqtt.SubscribeOptions | None = None, + properties: mqtt.Properties | None = None, + ) -> tuple[int, int]: + assert client._outgoing_calls_sem is not None + assert client._outgoing_calls_sem.locked() + return super().subscribe(topic, qos, options, properties) + + def unsubscribe( + self, topic: str | list[str], properties: mqtt.Properties | None = None + ) -> tuple[int, int]: + assert client._outgoing_calls_sem is not None + assert client._outgoing_calls_sem.locked() + return super().unsubscribe(topic, properties) + + def publish( + self, + topic: str, + payload: PayloadType | None = None, + qos: int = 0, + retain: bool = False, + properties: mqtt.Properties | None = None, + ) -> mqtt.MQTTMessageInfo: + assert client._outgoing_calls_sem is not None + assert client._outgoing_calls_sem.locked() + return super().publish(topic, payload, qos, retain, properties) + + monkeypatch.setattr(mqtt, "Client", MockPahoClient) + + async with Client(HOSTNAME, max_concurrent_outgoing_calls=1) as client: + await client.subscribe(topic) + await client.unsubscribe(topic) + await client.publish(topic) + + +async def test_client_websockets() -> None: + topic = TOPIC_HEADER + "websockets" + + async def handle_messages(tg: anyio.abc.TaskGroup) -> None: + async with client.filtered_messages(topic) as messages: + async for message in messages: + assert message.topic == topic + tg.cancel_scope.cancel() + + async with Client( + HOSTNAME, + 8080, + transport="websockets", + websocket_path="/", + websocket_headers={"foo": "bar"}, + ) as client: + async with anyio.create_task_group() as tg: + await client.subscribe(topic) + tg.start_soon(handle_messages, tg) + await client.publish(topic) + + +async def test_client_pending_calls_threshold(caplog: pytest.LogCaptureFixture) -> None: + topic = TOPIC_HEADER + "pending_calls_threshold" + + async with Client(HOSTNAME) as client: + nb_publish = client._pending_calls_threshold + 1 + + async with anyio.create_task_group() as tg: + for _ in range(nb_publish): + tg.start_soon(client.publish, topic) + + assert caplog.record_tuples == [ + ( + "mqtt", + logging.WARNING, + f"There are {nb_publish} pending publish calls.", + ) + ] + + +async def test_client_no_pending_calls_warnings_with_max_concurrent_outgoing_calls( + caplog: pytest.LogCaptureFixture, +) -> None: + topic = ( + TOPIC_HEADER + "no_pending_calls_warnings_with_max_concurrent_outgoing_calls" + ) + + async with Client(HOSTNAME, max_concurrent_outgoing_calls=1) as client: + nb_publish = client._pending_calls_threshold + 1 + + async with anyio.create_task_group() as tg: + for _ in range(nb_publish): + tg.start_soon(client.publish, topic) + + assert caplog.record_tuples == [] diff --git a/tests/test_error.py b/tests/test_error.py new file mode 100644 index 0000000..6d38501 --- /dev/null +++ b/tests/test_error.py @@ -0,0 +1,72 @@ +import paho.mqtt.client as mqtt +import pytest +from paho.mqtt.packettypes import PacketTypes + +from asyncio_mqtt.error import _CONNECT_RC_STRINGS, MqttCodeError, MqttConnectError + + +@pytest.mark.parametrize( + "rc", + ( + mqtt.MQTT_ERR_SUCCESS, + mqtt.MQTT_ERR_NOMEM, + mqtt.MQTT_ERR_PROTOCOL, + mqtt.MQTT_ERR_INVAL, + mqtt.MQTT_ERR_NO_CONN, + mqtt.MQTT_ERR_CONN_REFUSED, + mqtt.MQTT_ERR_NOT_FOUND, + mqtt.MQTT_ERR_CONN_LOST, + mqtt.MQTT_ERR_TLS, + mqtt.MQTT_ERR_PAYLOAD_SIZE, + mqtt.MQTT_ERR_NOT_SUPPORTED, + mqtt.MQTT_ERR_AUTH, + mqtt.MQTT_ERR_ACL_DENIED, + mqtt.MQTT_ERR_UNKNOWN, + mqtt.MQTT_ERR_ERRNO, + mqtt.MQTT_ERR_QUEUE_SIZE, + mqtt.MQTT_ERR_KEEPALIVE, + -1, + ), +) +def test_mqtt_code_error_int(rc: int) -> None: + assert str(MqttCodeError(rc)) == f"[code:{rc}] {mqtt.error_string(rc)}" + + +@pytest.mark.parametrize( + "packetType, aName", + ( + (PacketTypes.CONNACK, "Success"), + (PacketTypes.PUBACK, "Success"), + (PacketTypes.SUBACK, "Granted QoS 1"), + ), +) +def test_mqtt_code_error_reason_codes(packetType: int, aName: str) -> None: + rc = mqtt.ReasonCodes(packetType, aName) + assert str(MqttCodeError(rc)) == f"[code:{rc.value}] {str(rc)}" + + +def test_mqtt_code_error_none() -> None: + assert str(MqttCodeError(None)) == "[code:None] " + + +@pytest.mark.parametrize("rc, message", list(_CONNECT_RC_STRINGS.items()) + [(0, "")]) +def test_mqtt_connect_error_int(rc: int, message: str) -> None: + error = MqttConnectError(rc) + arg = "Connection refused" + if rc in _CONNECT_RC_STRINGS: + arg += f": {message}" + assert error.args[0] == arg + assert str(error) == f"[code:{rc}] {mqtt.error_string(rc)}" + + +@pytest.mark.parametrize( + "packetType, aName", + ( + (PacketTypes.CONNACK, "Success"), + (PacketTypes.PUBACK, "Success"), + (PacketTypes.SUBACK, "Granted QoS 1"), + ), +) +def test_mqtt_connect_error_reason_codes(packetType: int, aName: str) -> None: + rc = mqtt.ReasonCodes(packetType, aName) + assert str(MqttConnectError(rc)) == f"[code:{rc.value}] {str(rc)}"