diff --git a/nats/aio/client.py b/nats/aio/client.py index c5d7c5e5..7dbc1105 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -96,6 +96,7 @@ DEFAULT_FLUSH_TIMEOUT = 10 # in seconds DEFAULT_CONNECT_TIMEOUT = 2 # in seconds DEFAULT_DRAIN_TIMEOUT = 30 # in seconds +DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024 MAX_CONTROL_LINE_SIZE = 1024 NATS_HDR_LINE = bytearray(b'NATS/1.0') @@ -315,6 +316,7 @@ async def connect( inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX, pending_size: int = DEFAULT_PENDING_SIZE, flush_timeout: Optional[float] = None, + max_message_size: int = DEFAULT_MAX_MESSAGE_SIZE ) -> None: """ Establishes a connection to NATS. @@ -445,6 +447,7 @@ async def subscribe_handler(msg): self.options["token"] = token self.options["connect_timeout"] = connect_timeout self.options["drain_timeout"] = drain_timeout + self.options["max_message_size"] = max_message_size if tls: self.options['tls'] = tls @@ -1309,13 +1312,15 @@ async def _select_next_server(self) -> None: s.uri, ssl_context=self.ssl_context, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options['connect_timeout'], + max_msg_size=self.options['max_message_size'] ) else: await self._transport.connect( s.uri, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options['connect_timeout'], + max_msg_size=self.options['max_message_size'] ) self._current_server = s break diff --git a/nats/aio/transport.py b/nats/aio/transport.py index f32558ba..c0f88f4b 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -18,7 +18,8 @@ class Transport(abc.ABC): @abc.abstractmethod async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int + self, uri: ParseResult, buffer_size: int, connect_timeout: int, + max_message_size: int ): """ Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be @@ -28,11 +29,8 @@ async def connect( @abc.abstractmethod async def connect_tls( - self, - uri: Union[str, ParseResult], - ssl_context: ssl.SSLContext, - buffer_size: int, - connect_timeout: int, + self, uri: Union[str, ParseResult], ssl_context: ssl.SSLContext, + buffer_size: int, connect_timeout: int, max_message_size: int ): """ connect_tls is similar to connect except it tries to connect to a secure endpoint, using the provided ssl @@ -116,7 +114,8 @@ def __init__(self): self._io_writer: Optional[asyncio.StreamWriter] = None async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int + self, uri: ParseResult, buffer_size: int, connect_timeout: int, + max_message_size: int ): r, w = await asyncio.wait_for( asyncio.open_connection( @@ -136,11 +135,8 @@ async def connect( self._bare_io_writer = self._io_writer = w async def connect_tls( - self, - uri: Union[str, ParseResult], - ssl_context: ssl.SSLContext, - buffer_size: int, - connect_timeout: int, + self, uri: Union[str, ParseResult], ssl_context: ssl.SSLContext, + buffer_size: int, connect_timeout: int, max_message_size: int ) -> None: assert self._io_writer, f'{type(self).__name__}.connect must be called first' @@ -203,20 +199,18 @@ def __init__(self): self._using_tls: Optional[bool] = None async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int + self, uri: ParseResult, buffer_size: int, connect_timeout: int, + max_msg_size: int ): # for websocket library, the uri must contain the scheme already self._ws = await self._client.ws_connect( - uri.geturl(), timeout=connect_timeout + uri.geturl(), timeout=connect_timeout, max_msg_size=max_msg_size ) self._using_tls = False async def connect_tls( - self, - uri: Union[str, ParseResult], - ssl_context: ssl.SSLContext, - buffer_size: int, - connect_timeout: int, + self, uri: Union[str, ParseResult], ssl_context: ssl.SSLContext, + buffer_size: int, connect_timeout: int, max_msg_size: int ): if self._ws and not self._ws.closed: if self._using_tls: @@ -226,7 +220,8 @@ async def connect_tls( self._ws = await self._client.ws_connect( uri if isinstance(uri, str) else uri.geturl(), ssl=ssl_context, - timeout=connect_timeout + timeout=connect_timeout, + max_msg_size=max_msg_size ) self._using_tls = True