From c7ea581dc80c341946f5e7d3799eb49172a22e84 Mon Sep 17 00:00:00 2001 From: Stanley Kudrow Date: Sat, 5 Oct 2024 17:03:59 +0300 Subject: [PATCH] fix and format the tests/test_client.py module --- tests/test_client.py | 233 +++++++++++++++++-------------------------- 1 file changed, 94 insertions(+), 139 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index dcd96849..a921e034 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ import nats import nats.errors -from nats.aio.client import Client as NATS, __version__ +from nats.aio.client import Client as NATS, ClientStates, __version__ from nats.aio.errors import * from tests.utils import ( ClusteringDiscoveryAuthTestCase, @@ -29,7 +29,6 @@ class ClientUtilsTest(unittest.TestCase): - def test_default_connect_command(self): nc = NATS() nc.options["verbose"] = False @@ -54,7 +53,6 @@ def test_default_connect_command_with_name(self): class ClientTest(SingleServerTestCase): - @async_test async def test_default_connect(self): nc = await nats.connect() @@ -82,10 +80,9 @@ async def test_default_module_connect(self): def test_connect_syntax_sugar(self): nc = NATS() - nc._setup_server_pool([ - "nats://127.0.0.1:4222", "nats://127.0.0.1:4223", - "nats://127.0.0.1:4224" - ]) + nc._setup_server_pool( + ["nats://127.0.0.1:4222", "nats://127.0.0.1:4223", "nats://127.0.0.1:4224"] + ) self.assertEqual(3, len(nc._server_pool)) nc = NATS() @@ -316,9 +313,7 @@ async def subscription_handler(arg1, msg): partial_arg = arg1 msgs.append(msg) - partial_sub_handler = functools.partial( - subscription_handler, "example" - ) + partial_sub_handler = functools.partial(subscription_handler, "example") payload = b"hello world" await nc.connect() @@ -789,19 +784,13 @@ async def slow_worker_handler(msg): await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) - response = await nc.request( - "help", b"please", timeout=1, old_style=True - ) + response = await nc.request("help", b"please", timeout=1, old_style=True) self.assertEqual(b"Reply:1", response.data) - response = await nc.request( - "help", b"please", timeout=1, old_style=True - ) + response = await nc.request("help", b"please", timeout=1, old_style=True) self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): - msg = await nc.request( - "slow.help", b"please", timeout=0.1, old_style=True - ) + msg = await nc.request("slow.help", b"please", timeout=0.1, old_style=True) with self.assertRaises(nats.errors.NoRespondersError): await nc.request("nowhere", b"please", timeout=0.1, old_style=True) @@ -1035,7 +1024,6 @@ async def receiver_cb(msg): class ClientReconnectTest(MultiServerAuthTestCase): - @async_test async def test_connect_with_auth(self): nc = NATS() @@ -1064,9 +1052,7 @@ async def test_module_connect_with_auth(self): @async_test async def test_module_connect_with_options(self): - nc = await nats.connect( - "nats://127.0.0.1:4223", user="foo", password="bar" - ) + nc = await nats.connect("nats://127.0.0.1:4223", user="foo", password="bar") self.assertTrue(nc.is_connected) await nc.drain() self.assertTrue(nc.is_closed) @@ -1083,7 +1069,9 @@ async def err_cb(e): options = { "reconnect_time_wait": 0.2, - "servers": ["nats://hello:world@127.0.0.1:4223", ], + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], "max_reconnect_attempts": 3, "error_cb": err_cb, } @@ -1110,7 +1098,9 @@ async def err_cb(e): options = { "reconnect_time_wait": 0.2, - "servers": ["nats://hello:world@127.0.0.1:4223", ], + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], "max_reconnect_attempts": 3, "error_cb": err_cb, } @@ -1151,12 +1141,8 @@ async def err_cb(e): self.assertTrue(nc.is_connected) # Stop all servers so that there aren't any available to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.2) @@ -1231,12 +1217,8 @@ async def err_cb(e): # Stop all servers so that there aren't any available to reconnect # then start one of them again. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) @@ -1258,9 +1240,7 @@ async def err_cb(e): await asyncio.sleep(0) # Stop the server once again - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) @@ -1527,9 +1507,7 @@ async def worker_handler(msg): self.assertEqual(b"Reply:1", response.data) # Stop the first server and connect to another one asap. - asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) # FIXME: Find better way to wait for the server to be stopped. await asyncio.sleep(0.5) @@ -1547,12 +1525,15 @@ async def worker_handler(msg): class ClientAuthTokenTest(MultiServerAuthTokenTestCase): - @async_test async def test_connect_with_auth_token(self): nc = NATS() - options = {"servers": ["nats://token@127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://token@127.0.0.1:4223", + ] + } await nc.connect(**options) self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) @@ -1565,7 +1546,9 @@ async def test_connect_with_auth_token_option(self): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4223", ], + "servers": [ + "nats://127.0.0.1:4223", + ], "token": "token", } await nc.connect(**options) @@ -1580,7 +1563,9 @@ async def test_connect_with_bad_auth_token(self): nc = NATS() options = { - "servers": ["nats://token@127.0.0.1:4225", ], + "servers": [ + "nats://token@127.0.0.1:4225", + ], "allow_reconnect": False, "reconnect_time_wait": 0.1, "max_reconnect_attempts": 1, @@ -1637,9 +1622,7 @@ async def worker_handler(msg): self.assertTrue(nc.is_connected) # Trigger a reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("test", cb=worker_handler) @@ -1655,7 +1638,6 @@ async def worker_handler(msg): class ClientTLSTest(TLSServerTestCase): - @async_test async def test_connect(self): nc = NATS() @@ -1675,9 +1657,7 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect( - servers=["tls://127.0.0.1:4224"], allow_reconnect=False - ) + await nc.connect(servers=["tls://127.0.0.1:4224"], allow_reconnect=False) @async_test async def test_default_connect_using_tls_scheme_in_url(self): @@ -1730,7 +1710,6 @@ async def subscription_handler(msg): class ClientTLSReconnectTest(MultiTLSServerAuthTestCase): - @async_test async def test_tls_reconnect(self): nc = NATS() @@ -1783,9 +1762,7 @@ async def worker_handler(msg): self.assertEqual(b"Reply:1", response.data) # Trigger a reconnect and should be fine - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("example", cb=worker_handler) @@ -1803,16 +1780,13 @@ async def worker_handler(msg): class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): - @async_test async def test_connect(self): if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = await nats.connect( - "nats://127.0.0.1:4224", - tls=self.ssl_ctx, - tls_handshake_first=True + "nats://127.0.0.1:4224", tls=self.ssl_ctx, tls_handshake_first=True ) self.assertEqual(nc._server_info["max_payload"], nc.max_payload) self.assertTrue(nc._server_info["tls_required"]) @@ -1848,9 +1822,7 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - "tls://127.0.0.1:4224", - allow_reconnect=False, - tls_handshake_first=True + "tls://127.0.0.1:4224", allow_reconnect=False, tls_handshake_first=True ) @async_test @@ -1907,7 +1879,6 @@ async def subscription_handler(msg): class ClusterDiscoveryTest(ClusteringTestCase): - @async_test async def test_discover_servers_on_first_connect(self): nc = NATS() @@ -1923,14 +1894,16 @@ async def test_discover_servers_on_first_connect(self): ) await asyncio.sleep(1) - options = {"servers": ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.Mock() with mock.patch("asyncio.iscoroutinefunction", return_value=True): - await nc.connect( - **options, discovered_server_cb=discovered_server_cb - ) + await nc.connect(**options, discovered_server_cb=discovered_server_cb) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1942,12 +1915,14 @@ async def test_discover_servers_on_first_connect(self): async def test_discover_servers_after_first_connect(self): nc = NATS() - options = {"servers": ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.Mock() with mock.patch("asyncio.iscoroutinefunction", return_value=True): - await nc.connect( - **options, discovered_server_cb=discovered_server_cb - ) + await nc.connect(**options, discovered_server_cb=discovered_server_cb) # Start rest of cluster members so that we receive them # connect_urls on the first connect. @@ -1968,7 +1943,6 @@ async def test_discover_servers_after_first_connect(self): class ClusterDiscoveryReconnectTest(ClusteringDiscoveryAuthTestCase): - @async_test async def test_reconnect_to_new_server_with_auth(self): nc = NATS() @@ -1984,7 +1958,9 @@ async def err_cb(e): errors.append(e) options = { - "servers": ["nats://127.0.0.1:4223", ], + "servers": [ + "nats://127.0.0.1:4223", + ], "reconnected_cb": reconnected_cb, "error_cb": err_cb, "reconnect_time_wait": 0.1, @@ -2002,9 +1978,7 @@ async def handler(msg): await nc.subscribe("foo", cb=handler) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(reconnected, 2) msg = await nc.request("foo", b"hi") @@ -2065,9 +2039,7 @@ async def handler(msg): self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # Publishing while disconnected is an error if pending size is disabled. @@ -2135,9 +2107,7 @@ async def handler(msg): self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # While reconnecting the pending data will accumulate. @@ -2280,10 +2250,8 @@ async def handler(msg): class ConnectFailuresTest(SingleServerTestCase): - @async_test async def test_empty_info_op_uses_defaults(self): - async def bad_server(reader, writer): writer.write(b"INFO {}\r\n") await writer.drain() @@ -2302,7 +2270,9 @@ async def disconnected_cb(): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "disconnected_cb": disconnected_cb, } await nc.connect(**options) @@ -2313,7 +2283,6 @@ async def disconnected_cb(): @async_test async def test_empty_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"") await asyncio.sleep(0.2) @@ -2329,7 +2298,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2341,7 +2312,6 @@ async def error_cb(e): @async_test async def test_malformed_info_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"INF") await asyncio.sleep(0.2) @@ -2357,7 +2327,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2369,7 +2341,6 @@ async def error_cb(e): @async_test async def test_malformed_info_json_response_from_server(self): - async def bad_server(reader, writer): writer.write(b"INFO {\r\n") await asyncio.sleep(0.2) @@ -2385,7 +2356,9 @@ async def error_cb(e): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "error_cb": error_cb, "allow_reconnect": False, } @@ -2398,7 +2371,6 @@ async def error_cb(e): @async_test async def test_connect_timeout(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() @@ -2418,7 +2390,9 @@ async def reconnected_cb(): nc = NATS() options = { - "servers": ["nats://127.0.0.1:4555", ], + "servers": [ + "nats://127.0.0.1:4555", + ], "disconnected_cb": disconnected_cb, "reconnected_cb": reconnected_cb, "connect_timeout": 0.5, @@ -2435,7 +2409,6 @@ async def reconnected_cb(): @async_test async def test_connect_timeout_then_connect_to_healthy_server(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() @@ -2489,7 +2462,6 @@ async def error_cb(e): class ClientDrainTest(SingleServerTestCase): - @async_test async def test_drain_subscription(self): nc = NATS() @@ -2618,9 +2590,7 @@ async def replies(msg): "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "quux", - b"help", - reply=f"my-replies.{nc._nuid.next().decode()}" + "quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) # Relinquish control so that messages are processed. @@ -2680,9 +2650,7 @@ async def closed_cb(): nonlocal drain_done drain_done.set_result(True) - await nc.connect( - closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1 - ) + await nc.connect(closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1) nc2 = NATS() await nc2.connect() @@ -2714,9 +2682,7 @@ async def replies(msg): "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "quux", - b"help", - reply=f"my-replies.{nc._nuid.next().decode()}" + "quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) # Relinquish control so that messages are processed. @@ -2744,11 +2710,11 @@ def f(): pass for cb in [ - "error_cb", - "disconnected_cb", - "discovered_server_cb", - "closed_cb", - "reconnected_cb", + "error_cb", + "disconnected_cb", + "discovered_server_cb", + "closed_cb", + "reconnected_cb", ]: with self.assertRaises(nats.errors.InvalidCallbackTypeError): await nc.connect( @@ -2762,17 +2728,11 @@ def f(): async def test_protocol_mixing(self): nc = NATS() with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"]) @async_test async def test_drain_cancelled_errors_raised(self): @@ -2793,16 +2753,14 @@ async def cb(msg): await asyncio.sleep(0.1) with self.assertRaises(asyncio.CancelledError): with unittest.mock.patch( - "asyncio.wait_for", - unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - ), + "asyncio.wait_for", + unittest.mock.AsyncMock(side_effect=asyncio.CancelledError), ): await sub.drain() await nc.close() class NoAuthUserClientTest(NoAuthUserServerTestCase): - @async_test async def test_connect_user(self): fut = asyncio.Future() @@ -2823,11 +2781,11 @@ async def err_cb(e): await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): await msg.respond(b"pong") @@ -2858,11 +2816,11 @@ async def err_cb(e): await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): await msg.respond(b"pong") @@ -2877,7 +2835,6 @@ async def cb(msg): class ClientDisconnectTest(SingleServerTestCase): - @async_test async def test_close_while_disconnected(self): reconnected = asyncio.Future() @@ -2909,14 +2866,12 @@ async def disconnected_cb(): msg = await sub.next_msg() self.assertEqual(msg.data, b"First") - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) await nc.close() - disconnected_states[0] == NATS.RECONNECTING - disconnected_states[1] == NATS.CLOSED + disconnected_states[0] == ClientStates.RECONNECTING + disconnected_states[1] == ClientStates.CLOSED if __name__ == "__main__":