Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Transport.connect a classmethod #389

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
DEFAULT_SUB_PENDING_MSGS_LIMIT,
Subscription,
)
from .transport import TcpTransport, Transport, WebSocketTransport
from .transport import Transport, TcpTransport, WebSocketTransport

__version__ = '2.2.0'
__lang__ = 'python3'
Expand Down Expand Up @@ -1271,32 +1271,24 @@ async def _select_next_server(self) -> None:
# Not yet exceeded max_reconnect_attempts so can still use
# this server in the future.
self._server_pool.append(s)
if s.last_attempt is not None and now < s.last_attempt + self.options[
"reconnect_time_wait"]:
delay = self.options["reconnect_time_wait"]
if s.last_attempt is not None and now < s.last_attempt + delay:
# Backoff connecting to server if we attempted recently.
await asyncio.sleep(self.options["reconnect_time_wait"])
await asyncio.sleep(delay)
try:
s.last_attempt = time.monotonic()
if not self._transport:
if s.uri.scheme in ("ws", "wss"):
self._transport = WebSocketTransport()
else:
# use TcpTransport as a fallback
self._transport = TcpTransport()
if s.uri.scheme == "wss":
# wss is expected to connect directly with tls
await self._transport.connect_tls(
s.uri,
ssl_context=self.ssl_context,
buffer_size=DEFAULT_BUFFER_SIZE,
connect_timeout=self.options['connect_timeout']
)
transport_class: type[Transport]
if s.uri.scheme in ("ws", "wss"):
transport_class = WebSocketTransport
else:
await self._transport.connect(
s.uri,
buffer_size=DEFAULT_BUFFER_SIZE,
connect_timeout=self.options['connect_timeout']
)
transport_class = TcpTransport
ssl_context = self.ssl_context if s.uri.scheme == "wss" else None
self._transport = await transport_class.connect(
s.uri,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
ssl_context,
)
self._current_server = s
break
except Exception as e:
Expand Down Expand Up @@ -1885,7 +1877,6 @@ async def _process_connect_init(self) -> None:
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)

Expand Down
139 changes: 76 additions & 63 deletions nats/aio/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

class Transport(abc.ABC):

@classmethod
@abc.abstractmethod
async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
):
cls, uri: ParseResult, buffer_size: int, connect_timeout: int,
ssl_context: ssl.SSLContext | None
) -> Transport:
"""
Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be
obtained calling urllib.parse.urlparse.
Expand All @@ -28,7 +30,6 @@ async def connect_tls(
self,
uri: str | ParseResult,
ssl_context: ssl.SSLContext,
buffer_size: int,
connect_timeout: int,
):
"""
Expand All @@ -39,14 +40,14 @@ async def connect_tls(
pass

@abc.abstractmethod
def write(self, payload: bytes):
def write(self, payload: bytes) -> None:
"""
Write bytes to underlying transport. Needs a call to drain() to be successfully written.
"""
pass

@abc.abstractmethod
def writelines(self, payload: list[bytes]):
def writelines(self, payload: list[bytes]) -> None:
"""
Writes a list of bytes, one by one, to the underlying transport. Needs a call to drain() to be successfully
written.
Expand All @@ -69,21 +70,21 @@ async def readline(self) -> bytes:
pass

@abc.abstractmethod
async def drain(self):
async def drain(self) -> None:
"""
Flushes the bytes queued for transmission when calling write() and writelines().
"""
pass

@abc.abstractmethod
async def wait_closed(self):
async def wait_closed(self) -> None:
"""
Waits until the connection is successfully closed.
"""
pass

@abc.abstractmethod
def close(self):
def close(self) -> None:
"""
Closes the underlying transport.
"""
Expand All @@ -97,7 +98,7 @@ def at_eof(self) -> bool:
pass

@abc.abstractmethod
def __bool__(self):
def __bool__(self) -> bool:
"""
Returns if the transport was initialized, either by calling connect of connect_tls.
"""
Expand All @@ -106,41 +107,49 @@ def __bool__(self):

class TcpTransport(Transport):

def __init__(self):
self._bare_io_reader: asyncio.StreamReader | None = None
self._io_reader: asyncio.StreamReader | None = None
self._bare_io_writer: asyncio.StreamWriter | None = None
self._io_writer: asyncio.StreamWriter | None = None
def __init__(
self, r: asyncio.StreamReader, w: asyncio.StreamWriter
) -> None:
self._io_reader: asyncio.StreamReader = r
self._io_writer: asyncio.StreamWriter = w

# We keep a reference to the initial transport we used when
# establishing the connection in case we later upgrade to TLS
# after getting the first INFO message. This is in order to
# prevent the GC closing the socket after we send CONNECT
# and replace the transport.
#
# See https://github.com/nats-io/asyncio-nats/issues/43
self._bare_io_reader: asyncio.StreamReader = r
self._bare_io_writer: asyncio.StreamWriter = w

@classmethod
async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
):
cls, uri: ParseResult, buffer_size: int, connect_timeout: int,
ssl_context: ssl.SSLContext | None
) -> TcpTransport:
r, w = await asyncio.wait_for(
asyncio.open_connection(
host=uri.hostname,
port=uri.port,
limit=buffer_size,
), connect_timeout
)
# We keep a reference to the initial transport we used when
# establishing the connection in case we later upgrade to TLS
# after getting the first INFO message. This is in order to
# prevent the GC closing the socket after we send CONNECT
# and replace the transport.
#
# See https://github.com/nats-io/asyncio-nats/issues/43
self._bare_io_reader = self._io_reader = r
self._bare_io_writer = self._io_writer = w
transport = cls(r, w)
if ssl_context is not None:
await transport.connect_tls(
uri=uri,
ssl_context=ssl_context,
connect_timeout=connect_timeout,
)
return transport

async def connect_tls(
self,
uri: str | ParseResult,
ssl_context: ssl.SSLContext,
buffer_size: int,
connect_timeout: int,
) -> None:
assert self._io_writer, f'{type(self).__name__}.connect must be called first'

# manually recreate the stream reader/writer with a tls upgraded transport
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
Expand All @@ -157,60 +166,65 @@ async def connect_tls(
)
self._io_reader, self._io_writer = reader, writer

def write(self, payload):
return self._io_writer.write(payload)
def write(self, payload: bytes) -> None:
self._io_writer.write(payload)

def writelines(self, payload):
return self._io_writer.writelines(payload)
def writelines(self, payload: list[bytes]) -> None:
self._io_writer.writelines(payload)

async def read(self, buffer_size: int):
assert self._io_reader, f'{type(self).__name__}.connect must be called first'
async def read(self, buffer_size: int) -> bytes:
return await self._io_reader.read(buffer_size)

async def readline(self):
async def readline(self) -> bytes:
return await self._io_reader.readline()

async def drain(self):
return await self._io_writer.drain()
async def drain(self) -> None:
await self._io_writer.drain()

async def wait_closed(self):
async def wait_closed(self) -> None:
return await self._io_writer.wait_closed()

def close(self):
def close(self) -> None:
return self._io_writer.close()

def at_eof(self):
def at_eof(self) -> bool:
return self._io_reader.at_eof()

def __bool__(self):
def __bool__(self) -> bool:
return bool(self._io_writer) and bool(self._io_reader)


class WebSocketTransport(Transport):

def __init__(self):
def __init__(
self, ws: aiohttp.ClientWebSocketResponse,
client: aiohttp.ClientSession
):
self._ws = ws
self._client = client
self._pending: asyncio.Queue[bytes] = asyncio.Queue()
self._close_task: asyncio.Future[bool] = asyncio.Future()

@classmethod
async def connect(
cls, uri: ParseResult, buffer_size: int, connect_timeout: int,
ssl_context: ssl.SSLContext | None
) -> WebSocketTransport:
if not aiohttp:
raise ImportError(
"Could not import aiohttp transport, please install it with `pip install aiohttp`"
)
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._client: aiohttp.ClientSession = aiohttp.ClientSession()
self._pending = asyncio.Queue()
self._close_task = asyncio.Future()

async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
):
client = aiohttp.ClientSession()
# for websocket library, the uri must contain the scheme already
self._ws = await self._client.ws_connect(
uri.geturl(), timeout=connect_timeout
ws = await client.ws_connect(
uri.geturl(), timeout=connect_timeout, ssl=ssl_context
)
return cls(ws, client)

async def connect_tls(
self,
uri: str | ParseResult,
ssl_context: ssl.SSLContext,
buffer_size: int,
connect_timeout: int,
):
self._ws = await self._client.ws_connect(
Expand All @@ -219,39 +233,38 @@ async def connect_tls(
timeout=connect_timeout
)

def write(self, payload):
def write(self, payload: bytes) -> None:
self._pending.put_nowait(payload)

def writelines(self, payload):
def writelines(self, payload: list[bytes]) -> None:
for message in payload:
self.write(message)

async def read(self, buffer_size: int):
async def read(self, buffer_size: int) -> bytes:
return await self.readline()

async def readline(self):
async def readline(self) -> bytes:
data = await self._ws.receive()
if data.type == aiohttp.WSMsgType.CLOSE:
# if the connection terminated abruptly, return empty binary data to raise unexpected EOF
return b''
return data.data

async def drain(self):
async def drain(self) -> None:
# send all the messages pending
while not self._pending.empty():
message = self._pending.get_nowait()
await self._ws.send_bytes(message)

async def wait_closed(self):
async def wait_closed(self) -> None:
await self._close_task
await self._client.close()
self._ws = self._client = None

def close(self):
def close(self) -> None:
self._close_task = asyncio.create_task(self._ws.close())

def at_eof(self):
def at_eof(self) -> bool:
return self._ws.closed

def __bool__(self):
def __bool__(self) -> bool:
return bool(self._client)