diff --git a/nats/aio/client.py b/nats/aio/client.py index 81e65f50..f80766b5 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -16,6 +16,7 @@ import asyncio import base64 +import inspect import ipaddress import json import logging @@ -23,7 +24,7 @@ import string import time from collections import UserString -from dataclasses import dataclass +from dataclasses import dataclass, field, replace, fields from email.parser import BytesParser from io import BytesIO from pathlib import Path @@ -182,7 +183,45 @@ async def _default_error_callback(ex: Exception) -> None: Provides a default way to handle async errors if the user does not provide one. """ - _logger.error("nats: encountered error", exc_info=ex) + _logger.error('nats: encountered error', exc_info=ex) + + +@dataclass +class ConnectOptions: + servers: Union[str, List[str]] = field(default_factory=lambda: ["nats://localhost:4222"]) + error_cb: Optional[ErrorCallback] = None + disconnected_cb: Optional[Callback] = None + closed_cb: Optional[Callback] = None + discovered_server_cb: Optional[Callback] = None + reconnected_cb: Optional[Callback] = None + name: Optional[str] = None + pedantic: bool = False + verbose: bool = False + allow_reconnect: bool = True + connect_timeout: int = DEFAULT_CONNECT_TIMEOUT + reconnect_time_wait: int = DEFAULT_RECONNECT_TIME_WAIT + max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS + ping_interval: int = DEFAULT_PING_INTERVAL + max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS + dont_randomize: bool = False + flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE + no_echo: bool = False + tls: Optional[ssl.SSLContext] = None + tls_hostname: Optional[str] = None + tls_handshake_first: bool = False + user: Optional[str] = None + password: Optional[str] = None + token: Optional[str] = None + drain_timeout: int = DEFAULT_DRAIN_TIMEOUT + signature_cb: Optional[SignatureCallback] = None + user_jwt_cb: Optional[JWTCallback] = None + user_credentials: Optional[Credentials] = None + nkeys_seed: Optional[str] = None + nkeys_seed_str: Optional[str] = None + inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX + pending_size: int = DEFAULT_PENDING_SIZE + flush_timeout: Optional[float] = None + class Client: @@ -319,6 +358,7 @@ async def connect( inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX, pending_size: int = DEFAULT_PENDING_SIZE, flush_timeout: Optional[float] = None, + config: Optional[ConnectOptions] = None ) -> None: """ Establishes a connection to NATS. @@ -409,58 +449,84 @@ async def subscribe_handler(msg): """ + # Get the signature of the connect method + sig = inspect.signature(self.connect) + + # Get the default values from the signature + default_values = { + k: v.default + for k, v in sig.parameters.items() + if v.default is not inspect.Parameter.empty + } + + # Create a dictionary of the arguments and their values + kwargs = {k: v for k, v in locals().items() if k != "self"} + + # Extract the config object from kwargs + config = kwargs.pop("config", None) + + # Override only if the value differs from the default + kwargs = { + k: v + for k, v in kwargs.items() + if k in default_values and v != default_values[k] + } + + config = self._merge_config(config, **kwargs) + + # Set up callbacks for cb in [ - error_cb, - disconnected_cb, - closed_cb, - reconnected_cb, - discovered_server_cb, + config.error_cb, + config.disconnected_cb, + config.closed_cb, + config.reconnected_cb, + config.discovered_server_cb, ]: if cb and not asyncio.iscoroutinefunction(cb): raise errors.InvalidCallbackTypeError - self._setup_server_pool(servers) - self._error_cb = error_cb or _default_error_callback - self._closed_cb = closed_cb - self._discovered_server_cb = discovered_server_cb - self._reconnected_cb = reconnected_cb - self._disconnected_cb = disconnected_cb + self._setup_server_pool(config.servers) + self._error_cb = config.error_cb or _default_error_callback + self._closed_cb = config.closed_cb + self._discovered_server_cb = config.discovered_server_cb + self._reconnected_cb = config.reconnected_cb + self._disconnected_cb = config.disconnected_cb # Custom inbox prefix - if isinstance(inbox_prefix, str): - inbox_prefix = inbox_prefix.encode() - assert isinstance(inbox_prefix, bytes) - self._inbox_prefix = bytearray(inbox_prefix) + if isinstance(config.inbox_prefix, str): + config.inbox_prefix = config.inbox_prefix.encode() + assert isinstance(config.inbox_prefix, bytes) + self._inbox_prefix = bytearray(config.inbox_prefix) # NKEYS support - self._signature_cb = signature_cb - self._user_jwt_cb = user_jwt_cb - self._user_credentials = user_credentials - self._nkeys_seed = nkeys_seed - self._nkeys_seed_str = nkeys_seed_str + self._signature_cb = config.signature_cb + self._user_jwt_cb = config.user_jwt_cb + self._user_credentials = config.user_credentials + self._nkeys_seed = config.nkeys_seed + self._nkeys_seed_str = config.nkeys_seed_str # Customizable options - self.options["verbose"] = verbose - self.options["pedantic"] = pedantic - self.options["name"] = name - self.options["allow_reconnect"] = allow_reconnect - self.options["dont_randomize"] = dont_randomize - self.options["reconnect_time_wait"] = reconnect_time_wait - self.options["max_reconnect_attempts"] = max_reconnect_attempts - self.options["ping_interval"] = ping_interval - self.options["max_outstanding_pings"] = max_outstanding_pings - self.options["no_echo"] = no_echo - self.options["user"] = user - self.options["password"] = password - self.options["token"] = token - self.options["connect_timeout"] = connect_timeout - self.options["drain_timeout"] = drain_timeout - self.options["tls_handshake_first"] = tls_handshake_first - - if tls: - self.options["tls"] = tls - if tls_hostname: - self.options["tls_hostname"] = tls_hostname + self.options["verbose"] = config.verbose + self.options["pedantic"] = config.pedantic + self.options["name"] = config.name + self.options["allow_reconnect"] = config.allow_reconnect + self.options["dont_randomize"] = config.dont_randomize + self.options["reconnect_time_wait"] = config.reconnect_time_wait + self.options["max_reconnect_attempts"] = config.max_reconnect_attempts + self.options["ping_interval"] = config.ping_interval + self.options["max_outstanding_pings"] = config.max_outstanding_pings + self.options["no_echo"] = config.no_echo + self.options["user"] = config.user + self.options["password"] = config.password + self.options["token"] = config.token + self.options["connect_timeout"] = config.connect_timeout + self.options["drain_timeout"] = config.drain_timeout + self.options["tls_handshake_first"] = config.tls_handshake_first + + if config.tls: + self.options["tls"] = config.tls + if config.tls_hostname: + self.options["tls_hostname"] = config.tls_hostname # Check if the username or password was set in the server URI server_auth_configured = False @@ -469,7 +535,7 @@ async def subscribe_handler(msg): if server.uri.username or server.uri.password: server_auth_configured = True break - if user or password or token or server_auth_configured: + if config.user or config.password or config.token or server_auth_configured: self._auth_configured = True if (self._user_credentials is not None or self._nkeys_seed is not None @@ -478,13 +544,13 @@ async def subscribe_handler(msg): self._setup_nkeys_connect() # Queue used to trigger flushes to the socket. - self._flush_queue = asyncio.Queue(maxsize=flusher_queue_size) + self._flush_queue = asyncio.Queue(maxsize=config.flusher_queue_size) # Max size of buffer used for flushing commands to the server. - self._max_pending_size = pending_size + self._max_pending_size = config.pending_size # Max duration for a force flush (happens when a buffer is full). - self._flush_timeout = flush_timeout + self._flush_timeout = config.flush_timeout if self.options["dont_randomize"] is False: shuffle(self._server_pool) @@ -517,6 +583,21 @@ async def subscribe_handler(msg): self._current_server.last_attempt = time.monotonic() self._current_server.reconnects += 1 + def _merge_config( + self, config: Optional[ConnectOptions], **kwargs + ) -> ConnectOptions: + if not config: + config = ConnectOptions() + + defaults = {f.name: f.default for f in fields(ConnectOptions)} + + # Override only if the value differs from the default + updated = { + k: v for k, v in kwargs.items() if k in defaults and v != defaults[k] + } + + return replace(config, **updated) + def _setup_nkeys_connect(self) -> None: if self._user_credentials is not None: self._setup_nkeys_jwt_connect() @@ -1265,7 +1346,7 @@ async def _flush_pending( except asyncio.CancelledError: pass - def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: + def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None: if isinstance(connect_url, str): try: if "nats://" in connect_url or "tls://" in connect_url: