From e056c8ad4ef73acc1f88dffdcb2e54aab5323a7e Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 23 Nov 2022 09:22:15 +0100 Subject: [PATCH 1/4] Decouple .connect from Transport --- nats/aio/client.py | 39 ++++------- nats/aio/transport.py | 159 +++++++++++++++++++++--------------------- 2 files changed, 95 insertions(+), 103 deletions(-) diff --git a/nats/aio/client.py b/nats/aio/client.py index 46b19638..d9c91385 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -53,7 +53,7 @@ DEFAULT_SUB_PENDING_MSGS_LIMIT, Subscription, ) -from .transport import TcpTransport, Transport, WebSocketTransport +from .transport import Transport, connect_tcp, connect_ws, Connector __version__ = '2.2.0' __lang__ = 'python3' @@ -1271,32 +1271,24 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if s.last_attempt is not None and now < s.last_attempt + self.options[ - "reconnect_time_wait"]: + delay = self.options["reconnect_time_wait"] + if s.last_attempt is not None and now < s.last_attempt + delay: # Backoff connecting to server if we attempted recently. - await asyncio.sleep(self.options["reconnect_time_wait"]) + await asyncio.sleep(delay) try: s.last_attempt = time.monotonic() - if not self._transport: - if s.uri.scheme in ("ws", "wss"): - self._transport = WebSocketTransport() - else: - # use TcpTransport as a fallback - self._transport = TcpTransport() - if s.uri.scheme == "wss": - # wss is expected to connect directly with tls - await self._transport.connect_tls( - s.uri, - ssl_context=self.ssl_context, - buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] - ) + connector: Connector + if s.uri.scheme in ("ws", "wss"): + connector = connect_ws else: - await self._transport.connect( - s.uri, - buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] - ) + connector = connect_tcp + ssl_context = self.ssl_context if s.uri.scheme == "wss" else None + self._transport = await connector( + s.uri, + DEFAULT_BUFFER_SIZE, + self.options['connect_timeout'], + ssl_context, + ) self._current_server = s break except Exception as e: @@ -1885,7 +1877,6 @@ async def _process_connect_init(self) -> None: await self._transport.connect_tls( hostname, self.ssl_context, - DEFAULT_BUFFER_SIZE, self.options['connect_timeout'], ) diff --git a/nats/aio/transport.py b/nats/aio/transport.py index e5952eaf..20ca9aee 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -3,6 +3,7 @@ import abc import asyncio import ssl +from typing import Awaitable, Callable from urllib.parse import ParseResult try: @@ -12,23 +13,11 @@ class Transport(abc.ABC): - - @abc.abstractmethod - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): - """ - Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be - obtained calling urllib.parse.urlparse. - """ - pass - @abc.abstractmethod async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ): """ @@ -39,14 +28,14 @@ async def connect_tls( pass @abc.abstractmethod - def write(self, payload: bytes): + def write(self, payload: bytes) -> None: """ Write bytes to underlying transport. Needs a call to drain() to be successfully written. """ pass @abc.abstractmethod - def writelines(self, payload: list[bytes]): + def writelines(self, payload: list[bytes]) -> None: """ Writes a list of bytes, one by one, to the underlying transport. Needs a call to drain() to be successfully written. @@ -69,21 +58,21 @@ async def readline(self) -> bytes: pass @abc.abstractmethod - async def drain(self): + async def drain(self) -> None: """ Flushes the bytes queued for transmission when calling write() and writelines(). """ pass @abc.abstractmethod - async def wait_closed(self): + async def wait_closed(self) -> None: """ Waits until the connection is successfully closed. """ pass @abc.abstractmethod - def close(self): + def close(self) -> None: """ Closes the underlying transport. """ @@ -97,31 +86,45 @@ def at_eof(self) -> bool: pass @abc.abstractmethod - def __bool__(self): + def __bool__(self) -> bool: """ Returns if the transport was initialized, either by calling connect of connect_tls. """ pass +Connector = Callable[[ParseResult, int, int, 'ssl.SSLContext | None'], Awaitable[Transport]] + + +async def connect_tcp( + uri: ParseResult, + buffer_size: int, + connect_timeout: int, + ssl_context: ssl.SSLContext | None +) -> TcpTransport: + r, w = await asyncio.wait_for( + asyncio.open_connection( + host=uri.hostname, + port=uri.port, + limit=buffer_size, + ), connect_timeout + ) + transport = TcpTransport(r, w) + if ssl_context is not None: + await transport.connect_tls( + uri=uri, + ssl_context=ssl_context, + connect_timeout=connect_timeout, + ) + return transport + + class TcpTransport(Transport): - def __init__(self): - self._bare_io_reader: asyncio.StreamReader | None = None - self._io_reader: asyncio.StreamReader | None = None - self._bare_io_writer: asyncio.StreamWriter | None = None - self._io_writer: asyncio.StreamWriter | None = None + def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + self._io_reader: asyncio.StreamReader = r + self._io_writer: asyncio.StreamWriter = w - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): - r, w = await asyncio.wait_for( - asyncio.open_connection( - host=uri.hostname, - port=uri.port, - limit=buffer_size, - ), connect_timeout - ) # We keep a reference to the initial transport we used when # establishing the connection in case we later upgrade to TLS # after getting the first INFO message. This is in order to @@ -129,18 +132,15 @@ async def connect( # and replace the transport. # # See https://github.com/nats-io/asyncio-nats/issues/43 - self._bare_io_reader = self._io_reader = r - self._bare_io_writer = self._io_writer = w + self._bare_io_reader: asyncio.StreamReader = r + self._bare_io_writer: asyncio.StreamWriter = w async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ) -> None: - assert self._io_writer, f'{type(self).__name__}.connect must be called first' - # manually recreate the stream reader/writer with a tls upgraded transport reader = asyncio.StreamReader() protocol = asyncio.StreamReaderProtocol(reader) @@ -157,60 +157,62 @@ async def connect_tls( ) self._io_reader, self._io_writer = reader, writer - def write(self, payload): - return self._io_writer.write(payload) + def write(self, payload: bytes) -> None: + self._io_writer.write(payload) - def writelines(self, payload): - return self._io_writer.writelines(payload) + def writelines(self, payload: list[bytes]) -> None: + self._io_writer.writelines(payload) - async def read(self, buffer_size: int): - assert self._io_reader, f'{type(self).__name__}.connect must be called first' + async def read(self, buffer_size: int) -> bytes: return await self._io_reader.read(buffer_size) - async def readline(self): + async def readline(self) -> bytes: return await self._io_reader.readline() - async def drain(self): - return await self._io_writer.drain() + async def drain(self) -> None: + await self._io_writer.drain() - async def wait_closed(self): + async def wait_closed(self) -> None: return await self._io_writer.wait_closed() - def close(self): + def close(self) -> None: return self._io_writer.close() - def at_eof(self): + def at_eof(self) -> bool: return self._io_reader.at_eof() - def __bool__(self): + def __bool__(self) -> bool: return bool(self._io_writer) and bool(self._io_reader) +async def connect_ws( + uri: ParseResult, + buffer_size: int, + connect_timeout: int, + ssl_context: ssl.SSLContext | None +) -> WebSocketTransport: + if not aiohttp: + raise ImportError( + "Could not import aiohttp transport, please install it with `pip install aiohttp`" + ) + client = aiohttp.ClientSession() + # for websocket library, the uri must contain the scheme already + ws = await client.ws_connect(uri.geturl(), timeout=connect_timeout, ssl=ssl_context) + return WebSocketTransport(ws, client) + + class WebSocketTransport(Transport): - def __init__(self): - if not aiohttp: - raise ImportError( - "Could not import aiohttp transport, please install it with `pip install aiohttp`" - ) - self._ws: aiohttp.ClientWebSocketResponse | None = None - self._client: aiohttp.ClientSession = aiohttp.ClientSession() - self._pending = asyncio.Queue() - self._close_task = asyncio.Future() - - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): - # for websocket library, the uri must contain the scheme already - self._ws = await self._client.ws_connect( - uri.geturl(), timeout=connect_timeout - ) + def __init__(self, ws: aiohttp.ClientWebSocketResponse, client: aiohttp.ClientSession): + self._ws = ws + self._client = client + self._pending: asyncio.Queue[bytes] = asyncio.Queue() + self._close_task: asyncio.Future[bool] = asyncio.Future() async def connect_tls( self, uri: str | ParseResult, ssl_context: ssl.SSLContext, - buffer_size: int, connect_timeout: int, ): self._ws = await self._client.ws_connect( @@ -219,39 +221,38 @@ async def connect_tls( timeout=connect_timeout ) - def write(self, payload): + def write(self, payload: bytes) -> None: self._pending.put_nowait(payload) - def writelines(self, payload): + def writelines(self, payload: list[bytes]) -> None: for message in payload: self.write(message) - async def read(self, buffer_size: int): + async def read(self, buffer_size: int) -> bytes: return await self.readline() - async def readline(self): + async def readline(self) -> bytes: data = await self._ws.receive() if data.type == aiohttp.WSMsgType.CLOSE: # if the connection terminated abruptly, return empty binary data to raise unexpected EOF return b'' return data.data - async def drain(self): + async def drain(self) -> None: # send all the messages pending while not self._pending.empty(): message = self._pending.get_nowait() await self._ws.send_bytes(message) - async def wait_closed(self): + async def wait_closed(self) -> None: await self._close_task await self._client.close() - self._ws = self._client = None - def close(self): + def close(self) -> None: self._close_task = asyncio.create_task(self._ws.close()) - def at_eof(self): + def at_eof(self) -> bool: return self._ws.closed - def __bool__(self): + def __bool__(self) -> bool: return bool(self._client) From d831f32def513b8f3914de92e3319d82981d87d0 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 23 Nov 2022 09:53:12 +0100 Subject: [PATCH 2/4] Make Transport.connect a classmethod --- nats/aio/client.py | 10 ++--- nats/aio/transport.py | 99 ++++++++++++++++++++++++------------------- 2 files changed, 61 insertions(+), 48 deletions(-) diff --git a/nats/aio/client.py b/nats/aio/client.py index d9c91385..3d4bd51a 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -53,7 +53,7 @@ DEFAULT_SUB_PENDING_MSGS_LIMIT, Subscription, ) -from .transport import Transport, connect_tcp, connect_ws, Connector +from .transport import Transport, TcpTransport, WebSocketTransport __version__ = '2.2.0' __lang__ = 'python3' @@ -1277,13 +1277,13 @@ async def _select_next_server(self) -> None: await asyncio.sleep(delay) try: s.last_attempt = time.monotonic() - connector: Connector + transport_class: type[Transport] if s.uri.scheme in ("ws", "wss"): - connector = connect_ws + transport_class = WebSocketTransport else: - connector = connect_tcp + transport_class = TcpTransport ssl_context = self.ssl_context if s.uri.scheme == "wss" else None - self._transport = await connector( + self._transport = await transport_class.connect( s.uri, DEFAULT_BUFFER_SIZE, self.options['connect_timeout'], diff --git a/nats/aio/transport.py b/nats/aio/transport.py index 20ca9aee..62997ef3 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -3,7 +3,6 @@ import abc import asyncio import ssl -from typing import Awaitable, Callable from urllib.parse import ParseResult try: @@ -13,6 +12,21 @@ class Transport(abc.ABC): + @classmethod + @abc.abstractmethod + async def connect( + cls, + uri: ParseResult, + buffer_size: int, + connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> Transport: + """ + Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be + obtained calling urllib.parse.urlparse. + """ + pass + @abc.abstractmethod async def connect_tls( self, @@ -93,32 +107,6 @@ def __bool__(self) -> bool: pass -Connector = Callable[[ParseResult, int, int, 'ssl.SSLContext | None'], Awaitable[Transport]] - - -async def connect_tcp( - uri: ParseResult, - buffer_size: int, - connect_timeout: int, - ssl_context: ssl.SSLContext | None -) -> TcpTransport: - r, w = await asyncio.wait_for( - asyncio.open_connection( - host=uri.hostname, - port=uri.port, - limit=buffer_size, - ), connect_timeout - ) - transport = TcpTransport(r, w) - if ssl_context is not None: - await transport.connect_tls( - uri=uri, - ssl_context=ssl_context, - connect_timeout=connect_timeout, - ) - return transport - - class TcpTransport(Transport): def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: @@ -135,6 +123,30 @@ def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: self._bare_io_reader: asyncio.StreamReader = r self._bare_io_writer: asyncio.StreamWriter = w + @classmethod + async def connect( + cls, + uri: ParseResult, + buffer_size: int, + connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> TcpTransport: + r, w = await asyncio.wait_for( + asyncio.open_connection( + host=uri.hostname, + port=uri.port, + limit=buffer_size, + ), connect_timeout + ) + transport = cls(r, w) + if ssl_context is not None: + await transport.connect_tls( + uri=uri, + ssl_context=ssl_context, + connect_timeout=connect_timeout, + ) + return transport + async def connect_tls( self, uri: str | ParseResult, @@ -185,22 +197,6 @@ def __bool__(self) -> bool: return bool(self._io_writer) and bool(self._io_reader) -async def connect_ws( - uri: ParseResult, - buffer_size: int, - connect_timeout: int, - ssl_context: ssl.SSLContext | None -) -> WebSocketTransport: - if not aiohttp: - raise ImportError( - "Could not import aiohttp transport, please install it with `pip install aiohttp`" - ) - client = aiohttp.ClientSession() - # for websocket library, the uri must contain the scheme already - ws = await client.ws_connect(uri.geturl(), timeout=connect_timeout, ssl=ssl_context) - return WebSocketTransport(ws, client) - - class WebSocketTransport(Transport): def __init__(self, ws: aiohttp.ClientWebSocketResponse, client: aiohttp.ClientSession): @@ -209,6 +205,23 @@ def __init__(self, ws: aiohttp.ClientWebSocketResponse, client: aiohttp.ClientSe self._pending: asyncio.Queue[bytes] = asyncio.Queue() self._close_task: asyncio.Future[bool] = asyncio.Future() + @classmethod + async def connect( + cls, + uri: ParseResult, + buffer_size: int, + connect_timeout: int, + ssl_context: ssl.SSLContext | None + ) -> WebSocketTransport: + if not aiohttp: + raise ImportError( + "Could not import aiohttp transport, please install it with `pip install aiohttp`" + ) + client = aiohttp.ClientSession() + # for websocket library, the uri must contain the scheme already + ws = await client.ws_connect(uri.geturl(), timeout=connect_timeout, ssl=ssl_context) + return cls(ws, client) + async def connect_tls( self, uri: str | ParseResult, From 59e71fa865ee46ae3a08ddbd6d11bc47886437d9 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 23 Nov 2022 09:53:38 +0100 Subject: [PATCH 3/4] format code --- nats/aio/transport.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/nats/aio/transport.py b/nats/aio/transport.py index 62997ef3..d1d7c464 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -12,13 +12,11 @@ class Transport(abc.ABC): + @classmethod @abc.abstractmethod async def connect( - cls, - uri: ParseResult, - buffer_size: int, - connect_timeout: int, + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, ssl_context: ssl.SSLContext | None ) -> Transport: """ @@ -109,7 +107,9 @@ def __bool__(self) -> bool: class TcpTransport(Transport): - def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: + def __init__( + self, r: asyncio.StreamReader, w: asyncio.StreamWriter + ) -> None: self._io_reader: asyncio.StreamReader = r self._io_writer: asyncio.StreamWriter = w @@ -125,10 +125,7 @@ def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: @classmethod async def connect( - cls, - uri: ParseResult, - buffer_size: int, - connect_timeout: int, + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, ssl_context: ssl.SSLContext | None ) -> TcpTransport: r, w = await asyncio.wait_for( @@ -199,7 +196,10 @@ def __bool__(self) -> bool: class WebSocketTransport(Transport): - def __init__(self, ws: aiohttp.ClientWebSocketResponse, client: aiohttp.ClientSession): + def __init__( + self, ws: aiohttp.ClientWebSocketResponse, + client: aiohttp.ClientSession + ): self._ws = ws self._client = client self._pending: asyncio.Queue[bytes] = asyncio.Queue() @@ -207,10 +207,7 @@ def __init__(self, ws: aiohttp.ClientWebSocketResponse, client: aiohttp.ClientSe @classmethod async def connect( - cls, - uri: ParseResult, - buffer_size: int, - connect_timeout: int, + cls, uri: ParseResult, buffer_size: int, connect_timeout: int, ssl_context: ssl.SSLContext | None ) -> WebSocketTransport: if not aiohttp: @@ -219,7 +216,9 @@ async def connect( ) client = aiohttp.ClientSession() # for websocket library, the uri must contain the scheme already - ws = await client.ws_connect(uri.geturl(), timeout=connect_timeout, ssl=ssl_context) + ws = await client.ws_connect( + uri.geturl(), timeout=connect_timeout, ssl=ssl_context + ) return cls(ws, client) async def connect_tls( From c5ea730f4cb3608e9e5fc6fb469738ac0ce21d23 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 25 Nov 2022 10:36:57 +0100 Subject: [PATCH 4/4] trigger CI