diff --git a/README.md b/README.md index 1f6ca96c..48d647f5 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ async def main(): nc = await nats.connect("localhost") # Create JetStream context. - js = nc.jetstream() + js = nats.jetstream.new(nc) # Persist messages on 'foo's subject. await js.add_stream(name="sample-stream", subjects=["foo"]) diff --git a/examples/jetstream.py b/examples/jetstream.py deleted file mode 100644 index ca24a78b..00000000 --- a/examples/jetstream.py +++ /dev/null @@ -1,69 +0,0 @@ -import asyncio -import nats -from nats.errors import TimeoutError - - -async def main(): - nc = await nats.connect("localhost") - - # Create JetStream context. - js = nc.jetstream() - - # Persist messages on 'foo's subject. - await js.add_stream(name="sample-stream", subjects=["foo"]) - - for i in range(0, 10): - ack = await js.publish("foo", f"hello world: {i}".encode()) - print(ack) - - # Create pull based consumer on 'foo'. - psub = await js.pull_subscribe("foo", "psub") - - # Fetch and ack messagess from consumer. - for i in range(0, 10): - msgs = await psub.fetch(1) - for msg in msgs: - print(msg) - - # Create single ephemeral push based subscriber. - sub = await js.subscribe("foo") - msg = await sub.next_msg() - await msg.ack() - - # Create single push based subscriber that is durable across restarts. - sub = await js.subscribe("foo", durable="myapp") - msg = await sub.next_msg() - await msg.ack() - - # Create deliver group that will be have load balanced messages. - async def qsub_a(msg): - print("QSUB A:", msg) - await msg.ack() - - async def qsub_b(msg): - print("QSUB B:", msg) - await msg.ack() - await js.subscribe("foo", "workers", cb=qsub_a) - await js.subscribe("foo", "workers", cb=qsub_b) - - for i in range(0, 10): - ack = await js.publish("foo", f"hello world: {i}".encode()) - print("\t", ack) - - # Create ordered consumer with flow control and heartbeats - # that auto resumes on failures. - osub = await js.subscribe("foo", ordered_consumer=True) - data = bytearray() - - while True: - try: - msg = await osub.next_msg() - data.extend(msg.data) - except TimeoutError: - break - print("All data in stream:", len(data)) - - await nc.close() - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/examples/kv.py b/examples/kv.py deleted file mode 100644 index b32cc83a..00000000 --- a/examples/kv.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio -import nats - - -async def main(): - nc = await nats.connect() - js = nc.jetstream() - - # Create a KV - kv = await js.create_key_value(bucket='MY_KV') - - # Set and retrieve a value - await kv.put('hello', b'world') - entry = await kv.get('hello') - print(f'KeyValue.Entry: key={entry.key}, value={entry.value}') - # KeyValue.Entry: key=hello, value=world - - await nc.close() - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/nats/__init__.py b/nats/__init__.py index 32d01162..fefe5b35 100644 --- a/nats/__init__.py +++ b/nats/__init__.py @@ -15,7 +15,7 @@ from typing import List, Union from .aio.client import Client as NATS - +from .aio.msg import Msg async def connect( servers: Union[str, List[str]] = ["nats://localhost:4222"], diff --git a/nats/jetstream/__init__.py b/nats/jetstream/__init__.py new file mode 100644 index 00000000..cbc5e3ce --- /dev/null +++ b/nats/jetstream/__init__.py @@ -0,0 +1,5 @@ +from nats.jetstream.context import Context +from nats.aio.client import Client + +def new(client: Client) -> Context: + return Context(client) diff --git a/nats/jetstream/api.py b/nats/jetstream/api.py new file mode 100644 index 00000000..35745a16 --- /dev/null +++ b/nats/jetstream/api.py @@ -0,0 +1,96 @@ +import nats.aio.client +import nats.aio.msg + +import json +from typing import Optional, Any, Dict + +DEFAULT_PREFIX = "$JS.API" + +# Error codes +JETSTREAM_NOT_ENABLED_FOR_ACCOUNT = 10039 +JETSTREAM_NOT_ENABLED = 10076 +STREAM_NOT_FOUND = 10059 +STREAM_NAME_IN_USE = 10058 +CONSUMER_CREATE = 10012 +CONSUMER_NOT_FOUND = 10014 +CONSUMER_NAME_EXISTS = 10013 +CONSUMER_ALREADY_EXISTS = 10105 +CONSUMER_EXISTS = 10148 +DUPLICATE_FILTER_SUBJECTS = 10136 +OVERLAPPING_FILTER_SUBJECTS = 10138 +CONSUMER_EMPTY_FILTER = 10139 +CONSUMER_DOES_NOT_EXIST = 10149 +MESSAGE_NOT_FOUND = 10037 +BAD_REQUEST = 10003 +STREAM_WRONG_LAST_SEQUENCE = 10071 + +# TODO: What should we call this error type? +class JetStreamError(Exception): + code:str + description: str + + def __init__(self, code: str, description: str) -> None: + self.code = code + self.description = description + + def __str__(self) -> str: + return ( + f"nats: {type(self).__name__}: code={self.code} " + f"description='{self.description}'" + ) + +class Client: + """ + Provides methods for sending requests and processing responses via JetStream. + """ + + def __init__( + self, + inner: nats.aio.client.Client, + timeout: float = 2.0, + prefix: str = DEFAULT_PREFIX + ) -> None: + self.inner = inner + self.timeout = timeout + self.prefix = prefix + + async def request( + self, + subject: str, + payload: bytes, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None + ) -> nats.aio.msg.Msg: + if timeout is None: + timeout = self.timeout + + return await self.inner.request(subject, payload, timeout=timeout) + + # TODO return `jetstream.Msg` + async def request_msg( + self, + subject: str, + payload: bytes, + timeout: Optional[float] = None, + ) -> nats.aio.msg.Msg: + return await self.inner.request(subject, payload, timeout=timeout or self.timeout) + + async def request_json( + self, subject: str, data: Any, + timeout: float | None, + ) -> Dict[str, Any]: + request_subject = f"{self.prefix}.{subject}" + request_data = json.dumps(data).encode("utf-8") + response = await self.inner.request( + request_subject, request_data, timeout or self.timeout + ) + + response_data = json.loads(response.data.decode("utf-8")) + response_error = response_data.get("error") + if response_error: + raise JetStreamError( + code=response_error["err_code"], + description=response_error["description"], + ) + + return response_data diff --git a/nats/jetstream/consumer.py b/nats/jetstream/consumer.py new file mode 100644 index 00000000..9bc1bac0 --- /dev/null +++ b/nats/jetstream/consumer.py @@ -0,0 +1,625 @@ +from __future__ import annotations + +import random +import hashlib +import string + +from datetime import datetime +from enum import Enum +from typing import AsyncIterable, Optional, Literal, List, Protocol, Dict, Any, AsyncIterator, AsyncIterable +from dataclasses import dataclass, field + +from nats.jetstream.api import CONSUMER_NOT_FOUND, Client, JetStreamError + +CONSUMER_CREATE_ACTION = "create" +CONSUMER_UPDATE_ACTION = "update" +CONSUMER_CREATE_OR_UPDATE_ACTION = "" + +class DeliverPolicy(Enum): + """ + DeliverPolicy determines from which point to start delivering messages. + """ + ALL = "all" + """DeliverAllPolicy starts delivering messages from the very beginning of a stream.""" + + LAST = "last" + """DeliverLastPolicy will start the consumer with the last sequence received.""" + + NEW = "new" + """DeliverNewPolicy will only deliver new messages that are sent after the consumer is created.""" + + BY_START_SEQUENCE = "by_start_sequence" + """DeliverByStartSequencePolicy will deliver messages starting from a given sequence configured with OptStartSeq.""" + + BY_START_TIME = "by_start_time" + """DeliverByStartTimePolicy will deliver messages starting from a given time configured with OptStartTime.""" + + LAST_PER_SUBJECT = "last_per_subject" + """DeliverLastPerSubjectPolicy will start the consumer with the last message for all subjects received.""" + +class AckPolicy(Enum): + """ + AckPolicy determines how the consumer should acknowledge delivered messages. + """ + NONE = "none" + """AckNonePolicy requires no acks for delivered messages.""" + + ALL = "all" + """AckAllPolicy when acking a sequence number, this implicitly acks all sequences below this one as well.""" + + EXPLICIT = "explicit" + """AckExplicitPolicy requires ack or nack for all messages.""" + + +class ReplayPolicy(Enum): + """ + ReplayPolicy determines how the consumer should replay messages it + already has queued in the stream. + """ + INSTANT = "instant" + """ReplayInstantPolicy will replay messages as fast as possible.""" + + ORIGINAL = "original" + """ReplayOriginalPolicy will maintain the same timing as the messages were received.""" + + +@dataclass +class SequenceInfo: + """ + SequenceInfo has both the consumer and the stream sequence and last activity. + """ + consumer: int + """Consumer sequence number.""" + + stream: int + """Stream sequence number.""" + + last_active: Optional[datetime] = None + """Last activity timestamp.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> SequenceInfo: + return cls( + consumer=data['consumer_seq'], + stream=data['stream_seq'], + last_active=datetime.fromtimestamp(data['last_active']) if data.get('last_active') else None + ) + +@dataclass +class PeerInfo: + """ + PeerInfo shows information about the peers in the cluster that are + supporting the stream or consumer. + """ + + name: str + """The server name of the peer.""" + + current: bool + """Indicates if the peer is up to date and synchronized with the leader.""" + + active: int + """The duration since this peer was last seen.""" + + offline: Optional[bool] = None + """Indicates if the peer is considered offline by the group.""" + + lag: Optional[int] = None + """The number of uncommitted operations this peer is behind the leader.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> PeerInfo: + return cls( + name=data['name'], + current=data['current'], + active=data['active'], + offline=data.get('offline', None), + lag=data.get('lag', None) + ) + +@dataclass +class ClusterInfo: + """ + ClusterInfo shows information about the underlying set of servers that + make up the stream or consumer. + """ + + name: Optional[str] = None + """Name is the name of the cluster.""" + + leader: Optional[str] = None + """Leader is the server name of the RAFT leader.""" + + replicas: List[PeerInfo] = field( + default_factory=list + ) + """Replicas is the list of members of the RAFT cluster.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> ClusterInfo: + return cls( + name=data.get('name'), + leader=data.get('leader'), + replicas=[PeerInfo.from_dict(replica) for replica in data.get('replicas', [])] + ) + +@dataclass +class ConsumerConfig: + """ + ConsumerConfig is the configuration of a JetStream consumer. + """ + name: Optional[str] = None + """Optional name for the consumer.""" + + durable: Optional[str] = None + """Optional durable name for the consumer.""" + + description: Optional[str] = None + """Optional description of the consumer.""" + + deliver_policy: Optional[DeliverPolicy] = None + """Defines from which point to start delivering messages from the stream. Defaults to DeliverAllPolicy.""" + + opt_start_seq: Optional[int] = None + """Optional sequence number from which to start message delivery.""" + + opt_start_time: Optional[datetime] = None + """Optional time from which to start message delivery.""" + + ack_policy: Optional[AckPolicy] = None + """Defines the acknowledgement policy for the consumer. Defaults to AckExplicitPolicy.""" + + ack_wait: Optional[int] = None + """How long the server will wait for an acknowledgement before resending a message.""" + + max_deliver: Optional[int] = None + """Maximum number of delivery attempts for a message.""" + + backoff: Optional[List[int]] = None + """Optional back-off intervals for retrying message delivery after a failed acknowledgement.""" + + filter_subject: Optional[str] = None + """Can be used to filter messages delivered from the stream.""" + + replay_policy: Optional[ReplayPolicy] = None + """Defines the rate at which messages are sent to the consumer.""" + + rate_limit: Optional[int] = None + """Optional maximum rate of message delivery in bits per second.""" + + sample_frequency: Optional[str] = None + """Optional frequency for sampling how often acknowledgements are sampled for observability.""" + + max_waiting: Optional[int] = None + """Maximum number of pull requests waiting to be fulfilled.""" + + max_ack_pending: Optional[int] = None + """Maximum number of outstanding unacknowledged messages.""" + + headers_only: Optional[bool] = None + """Indicates whether only headers of messages should be sent.""" + + max_request_batch: Optional[int] = None + """Optional maximum batch size a single pull request can make.""" + + max_request_expires: Optional[int] = None + """Maximum duration a single pull request will wait for messages to be available to pull.""" + + max_request_max_bytes: Optional[int] = None + """Optional maximum total bytes that can be requested in a given batch.""" + + inactive_threshold: Optional[int] = None + """Duration which instructs the server to clean up the consumer if it has been inactive.""" + + replicas: Optional[int] = None + """Number of replicas for the consumer's state.""" + + memory_storage: Optional[bool] = None + """Flag to force the consumer to use memory storage.""" + + filter_subjects: Optional[List[str]] = None + """Allows filtering messages from a stream by subject.""" + + metadata: Optional[Dict[str, str]] = None + """Set of application-defined key-value pairs for associating metadata on the consumer.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> ConsumerConfig: + return cls( + name=data.get('name'), + durable=data.get('durable'), + description=data.get('description'), + deliver_policy=DeliverPolicy(data.get('deliver_policy')) if data.get('deliver_policy') else None, + opt_start_seq=data.get('opt_start_seq'), + opt_start_time=datetime.fromisoformat(data['opt_start_time']) if data.get('opt_start_time') else None, + ack_policy=AckPolicy(data.get('ack_policy')) if data.get('ack_policy') else None, + ack_wait=data.get('ack_wait'), + max_deliver=data.get('max_deliver'), + backoff=data.get('backoff'), + filter_subject=data.get('filter_subject'), + replay_policy=ReplayPolicy(data.get('replay_policy')) if data.get('replay_policy') else None, + rate_limit=data.get('rate_limit'), + sample_frequency=data.get('sample_frequency'), + max_waiting=data.get('max_waiting'), + max_ack_pending=data.get('max_ack_pending'), + headers_only=data.get('headers_only'), + max_request_batch=data.get('max_request_batch'), + max_request_expires=data.get('max_request_expires'), + max_request_max_bytes=data.get('max_request_max_bytes'), + inactive_threshold=data.get('inactive_threshold'), + replicas=data.get('replicas'), + memory_storage=data.get('memory_storage'), + filter_subjects=data.get('filter_subjects'), + metadata=data.get('metadata') + ) + + def to_dict(self) -> Dict[str, Any]: + return {key: value for key, value in { + 'name': self.name, + 'durable_name': self.durable, + 'description': self.description, + 'deliver_policy': self.deliver_policy, + 'opt_start_seq': self.opt_start_seq, + 'opt_start_time': self.opt_start_time, + 'ack_policy': self.ack_policy.value if self.ack_policy else None, + 'ack_wait': self.ack_wait, + 'max_deliver': self.max_deliver, + 'backoff': self.backoff, + 'filter_subject': self.filter_subject, + 'replay_policy': self.replay_policy, + 'rate_limit': self.rate_limit, + 'sample_frequency': self.sample_frequency, + 'max_waiting': self.max_waiting, + 'max_ack_pending': self.max_ack_pending, + 'headers_only': self.headers_only, + 'max_request_batch': self.max_request_batch, + 'max_request_expires': self.max_request_expires, + 'max_request_max_bytes': self.max_request_max_bytes, + 'inactive_threshold': self.inactive_threshold, + 'replicas': self.replicas, + 'memory_storage': self.memory_storage, + 'filter_subjects': self.filter_subjects, + 'metadata': self.metadata + }.items() if value is not None} + + +@dataclass +class ConsumerInfo: + """ + ConsumerInfo is the detailed information about a JetStream consumer. + """ + name: str + """Unique identifier for the consumer.""" + + stream_name: str + """Name of the stream that the consumer is bound to.""" + + created: datetime + """Timestamp when the consumer was created.""" + + config: ConsumerConfig + """Configuration settings of the consumer.""" + + delivered: SequenceInfo + """Information about the most recently delivered message.""" + + ack_floor: SequenceInfo + """Indicates the message before the first unacknowledged message.""" + + num_ack_pending: int + """Number of messages that have been delivered but not yet acknowledged.""" + + num_redelivered: int + """Counts the number of messages that have been redelivered and not yet acknowledged.""" + + num_waiting: int + """Count of active pull requests.""" + + num_pending: int + """Number of messages that match the consumer's filter but have not been delivered yet.""" + + timestamp: datetime + """Timestamp when the info was gathered by the server.""" + + push_bound: bool + """Indicates whether at least one subscription exists for the delivery subject of this consumer.""" + + cluster: Optional[ClusterInfo] = None + """Information about the cluster to which this consumer belongs.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> ConsumerInfo: + return cls( + name=data['name'], + stream_name=data['stream_name'], + created=data['created'], + config=ConsumerConfig.from_dict(data['config']), + delivered=SequenceInfo.from_dict(data['delivered']), + ack_floor=SequenceInfo.from_dict(data['ack_floor']), + num_ack_pending=data['num_ack_pending'], + num_redelivered=data['num_redelivered'], + num_waiting=data['num_waiting'], + num_pending=data['num_pending'], + timestamp=datetime.fromisoformat(data['ts']), + push_bound=data.get('push_bound', False), + cluster=ClusterInfo.from_dict(data['cluster']) if 'cluster' in data else None + ) + +@dataclass +class OrderedConsumerConfig: + filter_subjects: List[str] = field(default_factory=list) + deliver_policy: Optional[DeliverPolicy] = None + opt_start_seq: Optional[int] = None + opt_start_time: Optional[datetime] = None + replay_policy: Optional[ReplayPolicy] = None + inactive_threshold: int = 5_000_000_000 # 5 seconds in nanoseconds + headers_only: bool = False + max_reset_attempts: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + def convert(value): + if isinstance(value, Enum): + return value.name + elif isinstance(value, datetime): + return value.isoformat() + return value + + result = { + "filter_subjects": self.filter_subjects, + "deliver_policy": self.deliver_policy.value if self.deliver_policy else None, + "replay_policy": self.replay_policy.value if self.replay_policy else None, + "headers_only": self.headers_only, + "inactive_threshold": self.inactive_threshold, + } + + if self.opt_start_seq is not None: + result["opt_start_seq"] = self.opt_start_seq + if self.opt_start_time is not None: + result["opt_start_time"] = self.opt_start_time + if self.max_reset_attempts is not None: + result["max_reset_attempts"] = self.max_reset_attempts + + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> OrderedConsumerConfig: + kwargs = data.copy() + return cls(**kwargs) + +class ConsumerNotFoundError(Exception): + pass + +class ConsumerNameRequiredError(ValueError): + pass + +class InvalidConsumerNameError(ValueError): + pass + +class ConsumerExistsError(Exception): + pass + +class ConsumerMultipleFilterSubjectsNotSupportedError(Exception): + pass + +class Consumer(Protocol): + @property + def cached_info(self) -> ConsumerInfo: + ... + +class MessageBatch: + pass + +class PullConsumer(Consumer): + def __init__(self, client: Client, stream: str, name: str, info: ConsumerInfo): + self._client = client + self._stream = stream + self._name = name + self._cached_info = info + + @property + def cached_info(self) -> ConsumerInfo: + return self._cached_info + + + def fetch_bytes(self, max_bytes: int) -> MessageBatch: + return MessageBatch() + + + def _fetch(self) -> MessageBatch: + return MessageBatch() + +class ConsumerInfoLister(AsyncIterable): + def __init__(self, client: Client) -> None: + self._client = client + + def __aiter__(self) -> AsyncIterator[ConsumerInfo]: + raise NotImplementedError + +class ConsumerNameLister(AsyncIterable): + def __aiter__(self) -> AsyncIterator[str]: + raise NotImplementedError + +class StreamConsumerManager(Protocol): + async def create_consumer(self, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Creates a consumer on a given stream with given config. If consumer already exists + and the provided configuration differs from its configuration, ErrConsumerExists is raised. + If the provided configuration is the same as the existing consumer, the existing consumer + is returned. Consumer interface is returned, allowing to operate on a consumer (e.g. fetch messages). + """ + ... + + async def update_consumer(self, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Updates an existing consumer. If consumer does not exist, ErrConsumerDoesNotExist is raised. + Consumer interface is returned, allowing to operate on a consumer (e.g. fetch messages). + """ + ... + + async def create_or_update_consumer(self, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Creates a consumer on a given stream with given config. If consumer already exists, + it will be updated (if possible). Consumer interface is returned, allowing to operate + on a consumer (e.g. fetch messages). + """ + ... + + async def consumer(self, stream: str, consumer: str, timeout: Optional[float] = None) -> Consumer: + """ + Returns an interface to an existing consumer, allowing processing of messages. + If consumer does not exist, ErrConsumerNotFound is raised. + """ + ... + + async def delete_consumer(self, stream: str, consumer: str, timeout: Optional[float] = None) -> None: + """ + Removes a consumer with given name from a stream. + If consumer does not exist, `ConsumerNotFoundError` is raised. + """ + ... + +class ConsumerManager(Protocol): + async def create_consumer(self, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Creates a consumer on a given stream with given config. If consumer already exists + and the provided configuration differs from its configuration, `ConsumerExists` is raised. + If the provided configuration is the same as the existing consumer, the existing consumer + is returned. + """ + ... + + async def update_consumer(self, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Updates an existing consumer. If consumer does not exist, `ConsumerNotFound` is raised. + Consumer interface is returned, allowing to operate on a consumer (e.g. fetch messages). + """ + ... + + async def create_or_update_consumer(self, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + """ + Creates a consumer on a given stream with given config. If consumer already exists, + it will be updated (if possible). + """ + ... + + async def consumer(self, consumer: str, timeout: Optional[float] = None) -> Consumer: + """ + Returns an interface to an existing consumer, allowing processing of messages. + + If the consumer does not exist, `ConsumerNotFoundError` is raised. + """ + ... + + async def delete_consumer(self, consumer: str, timeout: Optional[float] = None) -> None: + """ + Removes a consumer with given name from a stream. + + If the consumer does not exist, `ConsumerNotFoundError` is raised. + """ + ... + + def list_consumers(self) -> ConsumerInfoLister: + """ + Returns ConsumerInfoLister enabling iterating over a channel of consumer infos. + """ + ... + + def consumer_names(self) -> ConsumerNameLister: + """ + Returns a ConsumerNameLister enabling iterating over a channel of consumer names. + """ + ... + + +def _generate_consumer_name() -> str: + name = ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + sha = hashlib.sha256(name.encode()).digest() + return ''.join(string.ascii_lowercase[b % 26] for b in sha[:8]) + +async def _upsert_consumer(client: Client, stream: str, config: ConsumerConfig, action: str, timeout: Optional[float] = None) -> Consumer: + consumer_name = config.name + if not consumer_name: + if config.durable: + consumer_name = config.durable + else: + consumer_name = _generate_consumer_name() + + _validate_consumer_name(consumer_name) + + if config.filter_subject and not config.filter_subjects: + create_consumer_subject = f"CONSUMER.CREATE.{stream}.{consumer_name}.{config.filter_subject}" + else: + create_consumer_subject = f"CONSUMER.CREATE.{stream}.{consumer_name}" + + create_consumer_request = { + 'stream_name': stream, + 'config': config.to_dict(), + 'action': action + } + + create_consumer_response = await client.request_json(create_consumer_subject, create_consumer_request, timeout=timeout) + + info = ConsumerInfo.from_dict(create_consumer_response) + if config.filter_subjects and not info.config.filter_subjects: + raise ConsumerMultipleFilterSubjectsNotSupportedError() + + # TODO support more than just pull consumers + return PullConsumer( + client=client, + name=consumer_name, + stream=stream, + info=info, + ) + +async def _create_consumer(client: Client, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + return await _upsert_consumer(client, stream=stream, config=config, action=CONSUMER_CREATE_ACTION, timeout=timeout) + +async def _update_consumer(client: Client, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + return await _upsert_consumer(client, stream=stream, config=config, action=CONSUMER_UPDATE_ACTION, timeout=timeout) + +async def _create_or_update_consumer(client: Client, stream: str, config: ConsumerConfig, timeout: Optional[float] = None) -> Consumer: + return await _upsert_consumer(client, stream=stream, config=config, action=CONSUMER_CREATE_OR_UPDATE_ACTION, timeout=timeout) + + +async def _get_consumer(client: Client, stream: str, name: str, timeout: Optional[float] = None) -> 'Consumer': + _validate_consumer_name(name) + consumer_info_request = {} + consumer_info_subject = f"CONSUMER.INFO.{stream}.{name}" + + try: + consumer_info_response = await client.request_json(consumer_info_subject, consumer_info_request, timeout=timeout) + except JetStreamError as jetstream_error: + if jetstream_error.code == CONSUMER_NOT_FOUND: + raise ConsumerNotFoundError from jetstream_error + + raise jetstream_error + + info = ConsumerInfo.from_dict(consumer_info_response) + + return PullConsumer( + client=client, + stream=stream, + name=name, + info=info, + ) + +async def _delete_consumer(client: Client, stream: str, consumer: str, timeout: Optional[float] = None) -> None: + _validate_consumer_name(consumer) + + delete_consumer_request = {} + delete_consumer_subject = f"CONSUMER.DELETE.{stream}.{consumer}" + + try: + delete_response = await client.request_json(delete_consumer_subject, delete_consumer_request, timeout=timeout) + except JetStreamError as jetstream_error: + if jetstream_error.code == CONSUMER_NOT_FOUND: + raise ConsumerNotFoundError() + + raise jetstream_error + +def _validate_consumer_name(name: str) -> None: + if not name: + raise ConsumerNameRequiredError() + + if any(c in name for c in ">*. /\\"): + raise InvalidConsumerNameError() diff --git a/nats/jetstream/context.py b/nats/jetstream/context.py new file mode 100644 index 00000000..2d4bd493 --- /dev/null +++ b/nats/jetstream/context.py @@ -0,0 +1,227 @@ +# Copyright 2016-2024 The NATS Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Type, TypeVar + +from nats.aio.client import Client as NATS +from typing import Optional + +from nats.errors import NoRespondersError + +from .api import * +from .stream import (Stream, StreamConfig, StreamInfo, StreamInfoLister, StreamManager, StreamNameAlreadyInUseError, StreamNameLister, StreamNotFoundError, StreamSourceMultipleFilterSubjectsNotSupported, StreamSourceNotSupportedError, StreamSubjectTransformNotSupportedError, _validate_stream_name) +from .publisher import (NoStreamResponseError, Publisher, PublishAck) +from .consumer import * + +class Context( + # Publisher, + StreamManager, + # StreamConsumerManager, +): + """ + Provides a context for interacting with JetStream. + The capabilities of JetStream include: + + - Publishing messages to a stream using `Publisher`. + - Managing streams using the `StreamManager` protocol. + - Managing consumers using the `StreamConsumerManager` protocol. + """ + + def __init__(self, nats: NATS, timeout: float = 2.0): + self._client = Client( + nats, + timeout=timeout, + ) + + async def publish(self, subject: str, payload: bytes, headers: Optional[Dict] = None, timeout: Optional[float] = None) -> PublishAck: + try: + response = await self._client.request(subject, payload, timeout) + except NoRespondersError as no_responders_error: + raise NoStreamResponseError from no_responders_error + + response_data = json.loads(response.data) + response_error = response_data.get("error") + if response_error: + raise JetStreamError( + code=response_error["err_code"], + description=response_error["description"], + ) + + return PublishAck.from_dict(response_data) + + async def create_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: + """ + Creates a new stream with given config. + """ + + stream_create_subject = f"STREAM.CREATE.{config.name}" + stream_create_request = config.to_dict() + try: + stream_create_response = await self._client.request_json( + stream_create_subject, + stream_create_request, + timeout=timeout + ) + except JetStreamError as jetstream_error: + if jetstream_error.code == STREAM_NAME_IN_USE: + raise StreamNameAlreadyInUseError() from jetstream_error + + raise jetstream_error + + info = StreamInfo.from_dict(stream_create_response) + + # Check if subject transforms are supported + if config.subject_transform and not info.config.subject_transform: + raise StreamSubjectTransformNotSupportedError() + + # Check if sources and subject transforms are supported + if config.sources: + if not info.config.sources: + raise StreamSourceNotSupportedError() + + for i in range(len(config.sources)): + source = config.sources[i] + response_source = config.sources[i] + + if source.subject_transforms and not response_source.subject_transforms: + raise StreamSourceMultipleFilterSubjectsNotSupported() + + return Stream( + client=self._client, + name=info.config.name, + info=info, + ) + + async def update_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: + """ + Updates an existing stream with the given config. + """ + + stream_create_subject = f"STREAM.UPDATE.{config.name}" + stream_create_request = config.to_dict() + try: + stream_create_response = await self._client.request_json( + stream_create_subject, + stream_create_request, + timeout=timeout + ) + except JetStreamError as jetstream_error: + if jetstream_error.code == STREAM_NAME_IN_USE: + raise StreamNameAlreadyInUseError() from jetstream_error + + if jetstream_error.code == STREAM_NOT_FOUND: + raise StreamNotFoundError() from jetstream_error + + raise jetstream_error + + info = StreamInfo.from_dict(stream_create_response) + + # Check if subject transforms are supported + if config.subject_transform and not info.config.subject_transform: + raise StreamSubjectTransformNotSupportedError() + + # Check if sources and subject transforms are supported + if config.sources: + if not info.config.sources: + raise StreamSourceNotSupportedError() + + for i in range(len(config.sources)): + source = config.sources[i] + response_source = config.sources[i] + + if source.subject_transforms and not response_source.subject_transforms: + raise StreamSourceMultipleFilterSubjectsNotSupported() + + return Stream( + client=self._client, + name=info.config.name, + info=info, + ) + + async def create_or_update_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: + """Creates a stream with given config or updates it if it already exists.""" + try: + return await self.update_stream(config, timeout=timeout) + except StreamNotFoundError: + return await self.create_stream(config, timeout=timeout) + + async def stream( + self, name: str, timeout: Optional[float] = None + ) -> Stream: + """Fetches `StreamInfo` and returns a `Stream` instance for a given stream name.""" + _validate_stream_name(name) + + stream_info_subject = f"STREAM.INFO.{name}" + stream_info_request = {} + try: + stream_info_response = await self._client.request_json( + stream_info_subject, + stream_info_request, + timeout=timeout + ) + except JetStreamError as jetstream_error: + if jetstream_error.code == STREAM_NOT_FOUND: + raise StreamNotFoundError() from jetstream_error + + raise jetstream_error + + info = StreamInfo.from_dict(stream_info_response) + + return Stream( + client=self._client, + name=info.config.name, + info=info, + ) + + async def stream_name_by_subject( + self, subject: str, timeout: Optional[float] = None + ) -> str: + """Returns a stream name listening on a given subject.""" + raise NotImplementedError + + async def delete_stream( + self, name: str, timeout: Optional[float] = None + ) -> None: + """Removes a stream with given name.""" + _validate_stream_name(name) + + stream_delete_subject = f"STREAM.DELETE.{name}" + stream_delete_request = {} + try: + stream_delete_response = await self._client.request_json( + stream_delete_subject, + stream_delete_request, + timeout=timeout + ) + except JetStreamError as response_error: + if response_error.code == STREAM_NOT_FOUND: + raise StreamNotFoundError() from response_error + + raise response_error + + def list_streams(self, + timeout: Optional[float] = None + ) -> StreamInfoLister: + raise NotImplementedError + + + def stream_names(self, + timeout: Optional[float] = None) -> StreamNameLister: + raise NotImplementedError diff --git a/nats/jetstream/publisher.py b/nats/jetstream/publisher.py new file mode 100644 index 00000000..34bb5024 --- /dev/null +++ b/nats/jetstream/publisher.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from asyncio import Future +from typing import Dict, Any, Protocol, Optional +from dataclasses import dataclass + + +@dataclass +class PublishAck: + """ + Represents the response of publishing a message to JetStream. + """ + + stream: str + """ + The stream name the message was published to. + """ + + sequence: int + """ + The stream sequence number of the message. + """ + + domain: Optional[str] = None + """ + The domain the message was published to. + """ + + duplicate: Optional[bool] = None + """ + Indicates whether the message was a duplicate. + """ + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> PublishAck: + return cls( + stream=data["stream"], + sequence=data["seq"], + domain=data.get("domain"), + duplicate=data.get("duplicate"), + ) + +class NoStreamResponseError(Exception): + """ + Raised when no response is received from the JetStream server. + """ + pass + +class Publisher(Protocol): + """ + A protocol for publishing messages to a stream. + """ + + async def publish(self, subject: str, payload: bytes) -> PublishAck: + """ + Publishes a message with the given payload on the given subject. + """ + ... + + async def publish_async( + self, + subject: str, + payload: bytes = b'', + wait_stall: Optional[float] = None, + ) -> Future[PublishAck]: + """ + Publishes a message with the given payload on the given subject without waiting for a server acknowledgement. + """ + ... diff --git a/nats/jetstream/stream.py b/nats/jetstream/stream.py new file mode 100644 index 00000000..699a1ef9 --- /dev/null +++ b/nats/jetstream/stream.py @@ -0,0 +1,901 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from dataclasses import dataclass, field +from typing import AsyncIterable, Dict, Any, Optional, List, AsyncIterator, Protocol + +from nats import jetstream + +from .api import ( + STREAM_NAME_IN_USE, + STREAM_NOT_FOUND, + Client, + JetStreamError, + JetStreamError, +) + +from .consumer import ( + ClusterInfo, + ConsumerConfig, + Consumer, + _create_consumer, + _delete_consumer, + _get_consumer, + _update_consumer, + _create_or_update_consumer, +) + +class RetentionPolicy(Enum): + """ + Determines how messages in a stream are retained. + """ + + LIMITS = "limits" + """LimitsPolicy means that messages are retained until any given limit is reached. This could be one of MaxMsgs, MaxBytes, or MaxAge.""" + + INTEREST = "interest" + """InterestPolicy specifies that when all known observables have acknowledged a message, it can be removed.""" + + WORKQUEUE = "workqueue" + """WorkQueuePolicy specifies that when the first worker or subscriber acknowledges the message, it can be removed.""" + + +class DiscardPolicy(Enum): + """ + Determines how to proceed when limits of messages or bytes are reached. + """ + + OLD = "old" + """DiscardOld will remove older messages to return to the limits. + + This is the default. + """ + + NEW = "new" + """DiscardNew will fail to store new messages once the limits are reached.""" + + +class StorageType(Enum): + """ + Determines how messages are stored for retention. + """ + + FILE = "file" + """ + Specifies on disk storage. + + This is the default. + """ + + MEMORY = "memory" + """ + Specifies in-memory storage. + """ + + +class StoreCompression(Enum): + """ + Determines how messages are compressed. + """ + + NONE = "none" + """ + Disables compression on the stream. + + This is the default. + """ + + S2 = "s2" + """ + Enables S2 compression on the stream. + """ + + +@dataclass +class SubjectTransformConfig: + """ + SubjectTransformConfig is for applying a subject transform (to matching + messages) before doing anything else when a new message is received. + """ + + source: str + """The subject pattern to match incoming messages against.""" + + destination: str + """The subject pattern to remap the subject to.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> SubjectTransformConfig: + return cls(source=data["src"], destination=data["dest"]) + + def to_dict(self) -> Dict[str, str]: + return {"src": self.source, "dest": self.destination} + + +@dataclass +class StreamSourceInfo: + """ + StreamSourceInfo shows information about an upstream stream source/mirror. + """ + + name: str + """Name is the name of the stream that is being replicated.""" + + lag: Optional[int] = None + """Lag informs how many messages behind the source/mirror operation is. This will only show correctly if there is active communication with stream/mirror.""" + + active: Optional[int] = None + """Active informs when last the mirror or sourced stream had activity. Value will be -1 when there has been no activity.""" + + filter_subject: Optional[str] = None + """FilterSubject is the subject filter defined for this source/mirror.""" + + subject_transforms: List[SubjectTransformConfig] = field(default_factory=list) + """SubjectTransforms is a list of subject transforms defined for this source/mirror.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamSourceInfo: + return cls( + name=data["name"], + lag=data.get("lag", None), + active=data.get("active", None), + filter_subject=data.get("filter_subject", None), + subject_transforms=[ + SubjectTransformConfig.from_dict(x) + for x in data.get("subject_transforms", []) + ], + ) + + +@dataclass +class StreamState: + """ + StreamState is the state of a JetStream stream at the time of request. + """ + + msgs: int + """The number of messages stored in the stream.""" + + bytes: int + """The number of bytes stored in the stream.""" + + first_sequence: int + """The the sequence number of the first message in the stream.""" + + first_time: datetime + """The timestamp of the first message in the stream.""" + + last_sequence: int + """The sequence number of the last message in the stream.""" + + last_time: datetime + """The timestamp of the last message in the stream.""" + + consumers: int + """The number of consumers on the stream.""" + + num_deleted: int + """NumDeleted is the number of messages that have been removed from the stream. Only deleted messages causing a gap in stream sequence numbers are counted. Messages deleted at the beginning or end of the stream are not counted.""" + + num_subjects: int + """NumSubjects is the number of unique subjects the stream has received messages on.""" + + deleted: Optional[List[int]] = None + """A list of sequence numbers that have been removed from the stream. This field will only be returned if the stream has been fetched with the DeletedDetails option.""" + + subjects: Optional[Dict[str, int]] = None + """Subjects is a map of subjects the stream has received messages on with message count per subject. This field will only be returned if the stream has been fetched with the SubjectFilter option.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamState: + return cls( + msgs=data["messages"], + bytes=data["bytes"], + first_sequence=data["first_seq"], + first_time=datetime.strptime(data["first_ts"], "%Y-%m-%dT%H:%M:%SZ"), + last_sequence=data["last_seq"], + last_time=datetime.strptime(data["last_ts"], "%Y-%m-%dT%H:%M:%SZ"), + consumers=data["consumer_count"], + num_deleted=data.get("num_deleted", 0), + num_subjects=data.get("num_subjects", 0), + deleted=data.get("deleted", None), + subjects=data.get("subjects", None), + ) + + +@dataclass +class Republish: + """ + RePublish is for republishing messages once committed to a stream. The + original subject is remapped from the subject pattern to the destination + pattern. + """ + + destination: str + """The subject pattern to republish the subject to.""" + + source: Optional[str] = None + """The subject pattern to match incoming messages against.""" + + headers_only: Optional[bool] = None + """A flag to indicate that only the headers should be republished.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Republish: + return cls( + destination=data["dest"], + source=data.get("src"), + headers_only=data.get("headers_only"), + ) + + def to_dict(self) -> Dict[str, Any]: + return { + key: value + for key, value in { + "dest": self.destination, + "src": self.source, + "headers_only": self.headers_only, + }.items() + if value is not None + } + + +@dataclass +class Placement: + """ + Placement is used to guide placement of streams in clustered JetStream. + """ + + cluster: str + """The name of the cluster to which the stream should be assigned.""" + + tags: List[str] = field(default_factory=list) + """Tags are used to match streams to servers in the cluster. A stream will be assigned to a server with a matching tag.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Placement: + return cls( + cluster=data["cluster"], + tags=data.get("tags", []), + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "cluster": self.cluster, + "tags": self.tags, + } + + +@dataclass +class ExternalStream: + """ + ExternalStream allows you to qualify access to a stream source in another + account. + """ + + api_prefix: str + """The subject prefix that imports the other account/domain $JS.API.CONSUMER.> subjects.""" + + deliver_prefix: str + """The delivery subject to use for the push consumer.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> ExternalStream: + return cls( + api_prefix=data["api"], + deliver_prefix=data["deliver"], + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "api": self.api_prefix, + "deliver": self.deliver_prefix, + } + + +@dataclass +class StreamSource: + """ + StreamSource dictates how streams can source from other streams. + """ + + name: str + """The name of the stream to source from.""" + + opt_start_seq: Optional[int] = None + """The sequence number to start sourcing from.""" + + opt_start_time: Optional[datetime] = None + """The timestamp of messages to start sourcing from.""" + + filter_subject: Optional[str] = None + """The subject filter used to only replicate messages with matching subjects.""" + + subject_transforms: Optional[List[SubjectTransformConfig]] = None + """ + A list of subject transforms to apply to matching messages. + + Subject transforms on sources and mirrors are also used as subject filters with optional transformations. + """ + + external: Optional[ExternalStream] = None + """A configuration referencing a stream source in another account or JetStream domain.""" + + domain: Optional[str] = None + """Used to configure a stream source in another JetStream domain. This setting will set the External field with the appropriate APIPrefix.""" + + def __post__init__(self): + if self.external and self.domain: + raise ValueError("cannot set both external and domain") + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamSource: + kwargs = data.copy() + + return cls( + name=data["name"], + opt_start_seq=data.get("opt_start_seq"), + opt_start_time=data.get("opt_start_time"), + filter_subject=data.get("filter_subject"), + subject_transforms=[ + SubjectTransformConfig.from_dict(subject_transform) + for subject_transform in data.get("subject_transforms", []) + ], + external=ExternalStream.from_dict(data["external"]) + if data.get("external") + else None, + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "opt_start_seq": self.opt_start_seq, + "opt_start_time": self.opt_start_time, + "filter_subject": self.filter_subject, + "subject_transforms": [ + subject_transform.to_dict() + for subject_transform in self.subject_transforms + ] + if self.subject_transforms + else None, + "external": self.external.to_dict() if self.external else None, + } + + +@dataclass +class StreamConsumerLimits: + """ + StreamConsumerLimits are the limits for a consumer on a stream. These can + be overridden on a per consumer basis. + """ + + inactive_threshold: Optional[int] = None + """A duration which instructs the server to clean up the consumer if it has been inactive for the specified duration.""" + + max_ack_pending: Optional[int] = None + """A maximum number of outstanding unacknowledged messages for a consumer.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamConsumerLimits: + return cls( + max_ack_pending=data.get("max_ack_pending"), + inactive_threshold=data.get("inactive_threshold"), + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "max_ack_pending": self.max_ack_pending, + "inactive_threshold": self.inactive_threshold, + } + + +@dataclass +class StreamConfig: + """ + StreamConfig is the configuration of a JetStream stream. + """ + + name: str + """Name is the name of the stream. It is required and must be unique across the JetStream account. Names cannot contain whitespace, ., >, path separators (forward or backwards slash), and non-printable characters.""" + + description: Optional[str] = None + """Description is an optional description of the stream.""" + + subjects: List[str] = field(default_factory=list) + """Subjects is a list of subjects that the stream is listening on. Wildcards are supported. Subjects cannot be set if the stream is created as a mirror.""" + + retention: Optional[RetentionPolicy] = None + """Retention defines the message retention policy for the stream. Defaults to RetentionPolicy.LIMITS.""" + + max_consumers: Optional[int] = None + """MaxConsumers specifies the maximum number of consumers allowed for the stream.""" + + max_msgs: Optional[int] = None + """MaxMsgs is the maximum number of messages the stream will store. After reaching the limit, stream adheres to the discard policy. If not set, server default is -1 (unlimited).""" + + max_bytes: Optional[int] = None + """MaxBytes is the maximum total size of messages the stream will store. After reaching the limit, stream adheres to the discard policy. If not set, server default is -1 (unlimited).""" + + discard: Optional[DiscardPolicy] = None + """Discard defines the policy for handling messages when the stream reaches its limits in terms of number of messages or total bytes. Defaults to DiscardPolicy.OLD if not set""" + + discard_new_per_subject: Optional[bool] = None + """DiscardNewPerSubject is a flag to enable discarding new messages per subject when limits are reached. Requires DiscardPolicy to be DiscardNew and the MaxMsgsPerSubject to be set.""" + + max_age: Optional[int] = None + """MaxAge is the maximum age of messages that the stream will retain.""" + + max_msgs_per_subject: Optional[int] = None + """MaxMsgsPerSubject is the maximum number of messages per subject that the stream will retain.""" + + max_msg_size: Optional[int] = None + """MaxMsgSize is the maximum size of any single message in the stream.""" + + storage: StorageType = StorageType.FILE + """Storage specifies the type of storage backend used for the stream (file or memory). Defaults to StorageType.FILE """ + + replicas: int = 1 + """Replicas is the number of stream replicas in clustered JetStream. Defaults to 1, maximum is 5.""" + + no_ack: Optional[bool] = None + """NoAck is a flag to disable acknowledging messages received by this stream. If set to true, publish methods from the JetStream client will not work as expected, since they rely on acknowledgements. Core NATS publish methods should be used instead. Note that this will make message delivery less reliable.""" + + duplicates: Optional[int] = None + """Duplicates is the window within which to track duplicate messages. If not set, server default is 2 minutes.""" + + placement: Optional[Placement] = None + """Placement is used to declare where the stream should be placed via tags and/or an explicit cluster name.""" + + mirror: Optional[StreamSource] = None + """Mirror defines the configuration for mirroring another stream.""" + + sources: Optional[List[StreamSource]] = None + """Sources is a list of other streams this stream sources messages from.""" + + sealed: Optional[bool] = None + """Sealed streams do not allow messages to be published or deleted via limits or API, sealed streams cannot be unsealed via configuration update. Can only be set on already created streams via the Update API.""" + + deny_delete: Optional[bool] = None + """ + Restricts the ability to delete messages from a stream via the API. + + Server defaults to false when not set. + """ + + deny_purge: Optional[bool] = None + """Restricts the ability to purge messages from a stream via the API. + + Server defaults to false from server when not set. + """ + + allow_rollup: Optional[bool] = None + """Allows the use of the `Nats-Rollup` header to replace all contents of a stream, or subject in a stream, with a single new message. + """ + + compression: Optional[StoreCompression] = None + """ + Specifies the message storage compression algorithm. + + Server defaults to `StoreCompression.NONE` when not set. + """ + + first_sequence: Optional[int] = None + """The initial sequence number of the first message in the stream.""" + + subject_transform: Optional[SubjectTransformConfig] = None + """Allows applying a transformation to matching messages' subjects.""" + + republish: Optional[Republish] = None + """Allows immediate republishing of a message to the configured subject after it's stored.""" + + allow_direct: bool = False + """ + Enables direct access to individual messages using direct get. + + Server defaults to false. + """ + + mirror_direct: bool = False + """ + Enables direct access to individual messages from the origin stream. + + Defaults to false. + """ + + consumer_limits: Optional[StreamConsumerLimits] = None + """Defines limits of certain values that consumers can set, defaults for those who don't set these settings.""" + + metadata: Optional[Dict[str, str]] = None + """Provides a set of application-defined key-value pairs for associating metadata on the stream. + + Note: This feature requires nats-server v2.10.0 or later. + """ + + def __post_init__(self): + _validate_stream_name(self.name) + + if self.max_msgs_per_subject and not self.discard: + raise ValueError("max_msgs_per_subject requires discard policy to be set") + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamConfig: + return cls( + name=data["name"], + description=data.get("description"), + subjects=data.get("subjects", []), + retention=RetentionPolicy(data["retention"]) + if data.get("retention") + else None, + max_consumers=data.get("max_consumers"), + max_msgs=data.get("max_msgs"), + max_bytes=data.get("max_bytes"), + discard=DiscardPolicy(data["discard"]) if data.get("discard") else None, + discard_new_per_subject=data.get("discard_new_per_subject"), + max_age=data.get("max_age"), + max_msgs_per_subject=data.get("max_msgs_per_subject"), + max_msg_size=data.get("max_msg_size"), + storage=StorageType(data["storage"]), + replicas=data.get("num_replicas", 1), + no_ack=data.get("no_ack"), + duplicates=data.get("duplicates"), + placement=Placement.from_dict(data["placement"]) + if data.get("placement") + else None, + mirror=StreamSource.from_dict(data["mirror"]) + if data.get("mirror") + else None, + sources=[ + StreamSource.from_dict(source) for source in data.get("sources", []) + ], + sealed=data.get("sealed"), + deny_delete=data.get("deny_delete"), + deny_purge=data.get("deny_purge"), + allow_rollup=data.get("allow_rollup"), + compression=StoreCompression(data["compression"]) + if data.get("compression") + else None, + first_sequence=data.get("first_sequence"), + subject_transform=SubjectTransformConfig.from_dict( + data["subject_transform"] + ) + if data.get("subject_transform") + else None, + republish=Republish.from_dict(data["republish"]) + if data.get("republish") + else None, + allow_direct=data.get("allow_direct", False), + mirror_direct=data.get("mirror_direct", False), + consumer_limits=StreamConsumerLimits.from_dict(data["consumer_limits"]) + if data.get("consumer_limits") + else None, + metadata=data.get("metadata"), + ) + + def to_dict(self) -> Dict[str, Any]: + return { + k: v + for k, v in { + "name": self.name, + "description": self.description, + "subjects": self.subjects, + "retention": self.retention.value if self.retention else None, + "max_consumers": self.max_consumers, + "max_msgs": self.max_msgs, + "max_bytes": self.max_bytes, + "discard": self.discard.value if self.discard else None, + "discard_new_per_subject": self.discard_new_per_subject, + "max_age": self.max_age, + "max_msgs_per_subject": self.max_msgs_per_subject, + "max_msg_size": self.max_msg_size, + "storage": self.storage.value, + "num_replicas": self.replicas, + "no_ack": self.no_ack, + "duplicate_window": self.duplicates, + "placement": self.placement.to_dict() if self.placement else None, + "mirror": self.mirror.to_dict() if self.mirror else None, + "sources": [source.to_dict() for source in self.sources] + if self.sources + else None, + "sealed": self.sealed, + "deny_delete": self.deny_delete, + "deny_purge": self.deny_purge, + "allow_rollup": self.allow_rollup, + "compression": self.compression.value if self.compression else None, + "first_seq": self.first_sequence, + "subject_transform": self.subject_transform.to_dict() + if self.subject_transform + else None, + "republish": self.republish.to_dict() if self.republish else None, + "allow_direct": self.allow_direct, + "mirror_direct": self.mirror_direct, + "consumer_limits": self.consumer_limits.to_dict() + if self.consumer_limits + else None, + "metadata": self.metadata, + }.items() + if v is not None + } + + +@dataclass +class StreamInfo: + """ + Provides configuration and current state for a stream. + """ + + config: StreamConfig + """Contains the configuration settings of the stream, set when creating or updating the stream.""" + + timestamp: datetime + """Indicates when the info was gathered by the server.""" + + created: datetime + """The timestamp when the stream was created.""" + + state: StreamState + """Provides the state of the stream at the time of request, including metrics like the number of messages in the stream, total bytes, etc.""" + + cluster: Optional[ClusterInfo] = None + """Contains information about the cluster to which this stream belongs (if applicable).""" + + mirror: Optional[StreamSourceInfo] = None + """Contains information about another stream this one is mirroring. Mirroring is used to create replicas of another stream's data. This field is omitted if the stream is not mirroring another stream.""" + + sources: List[StreamSourceInfo] = field(default_factory=list) + """A list of source streams from which this stream collects data.""" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> StreamInfo: + return cls( + config=StreamConfig.from_dict(data["config"]), + timestamp=datetime.fromisoformat(data["ts"]), + created=datetime.fromisoformat(data["created"]), + state=StreamState.from_dict(data["state"]), + cluster=ClusterInfo.from_dict(data["cluster"]) + if "cluster" in data + else None, + mirror=StreamSourceInfo.from_dict(data["mirror"]) + if "mirror" in data + else None, + sources=[StreamSourceInfo.from_dict(source) for source in data["sources"]] + if "sources" in data + else [], + ) + + +class StreamNameAlreadyInUse(Exception): + """ + Raised when trying to create a stream with a name that is already in use. + """ + + pass + + +class StreamNotFoundError(Exception): + """ + Raised when trying to access a stream that does not exist. + """ + + pass + + +class StreamSourceNotSupportedError(Exception): + """ + Raised when a source stream is not supported by the server. + """ + + pass + + +class StreamSubjectTransformNotSupportedError(Exception): + """ + Raised when a subject transform is not supported by the server. + """ + + pass + + +class StreamSourceMultipleFilterSubjectsNotSupported(Exception): + """ + Raised when multiple filter subjects are not supported by the server. + """ + + pass + + +class InvalidStreamNameError(ValueError): + """ + Raised when an invalid stream name is provided. + """ + + pass + + +class StreamNameRequiredError(ValueError): + """ + Raised when a stream name is required but not provided (e.g empty). + """ + + pass + + +class StreamNameAlreadyInUseError(Exception): + """ + Raised when a stream name is already in use. + """ + + pass + + +class Stream: + def __init__(self, client: Client, name: str, info: StreamInfo) -> None: + self._client = client + self._name = name + self._cached_info = info + + @property + def cached_info(self) -> StreamInfo: + """ + Returns the cached `StreamInfo` for the stream. + """ + return self._cached_info + + async def info( + self, + subject_filter: Optional[str] = None, + deleted_details: Optional[bool] = None, + timeout: Optional[float] = None + ) -> StreamInfo: + """Returns `StreamInfo` from the server.""" + # TODO(caspervonb): handle pagination + stream_info_subject = f"STREAM.INFO.{self._name}" + stream_info_request = { + "subject_filter": subject_filter, + "deleted_details": deleted_details, + } + try: + info_response = await self._client.request_json( + stream_info_subject, stream_info_request, timeout=timeout + ) + except JetStreamError as jetstream_error: + if jetstream_error.code == STREAM_NOT_FOUND: + raise StreamNotFoundError() from jetstream_error + + raise jetstream_error + + info = StreamInfo.from_dict(info_response) + self._cached_info = info + + return info + + + # TODO(caspervonb): Go does not return anything for this operation, should we? + async def purge( + self, + sequence: Optional[int] = None, + keep: Optional[int] = None, + subject: Optional[str] = None, + timeout: Optional[float] = None + ) -> None: + """ + Removes messages from a stream. + This is a destructive operation. + """ + + # TODO(caspervonb): enforce types with overloads + if keep is not None and sequence is not None: + raise ValueError( + "both 'keep' and 'sequence' cannot be provided in purge request" + ) + + stream_purge_subject = f"STREAM.PURGE.{self._name}" + stream_purge_request = { + "sequence": sequence, + "keep": keep, + "subject": subject, + } + + try: + stream_purge_response = await self._client.request_json( + stream_purge_subject, stream_purge_request, timeout=timeout + ) + except JetStreamError as jetstream_error: + raise jetstream_error + + + async def create_consumer( + self, config: ConsumerConfig, timeout: Optional[float] = None + ) -> Consumer: + return await _create_consumer( + self._client, stream=self._name, config=config, timeout=timeout + ) + + async def update_consumer( + self, config: ConsumerConfig, timeout: Optional[float] = None + ) -> Consumer: + return await _update_consumer( + self._client, stream=self._name, config=config, timeout=timeout + ) + + async def create_or_update_consumer( + self, config: ConsumerConfig, timeout: Optional[float] = None + ) -> Consumer: + return await _create_or_update_consumer( + self._client, stream=self._name, config=config, timeout=timeout + ) + + async def consumer(self, name: str, timeout: Optional[float] = None) -> Consumer: + return await _get_consumer( + self._client, stream=self._name, name=name, timeout=timeout + ) + + async def delete_consumer(self, name: str, timeout: Optional[float] = None) -> None: + return await _delete_consumer( + self._client, stream=self._name, consumer=name, timeout=timeout + ) + + +class StreamNameLister(AsyncIterable): + pass + + +class StreamInfoLister(AsyncIterable): + pass + + +class StreamManager(Protocol): + """ + Provides methods for managing streams. + """ + + async def create_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: + """ + Creates a new stream with given config. + """ + ... + + async def update_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: + """ + Updates an existing stream with the given config. + """ + ... + + async def create_or_update_stream( + self, config: StreamConfig, timeout: Optional[float] = None + ) -> Stream: ... + + async def stream(self, name: str, timeout: Optional[float] = None) -> Stream: + """Fetches `StreamInfo` and returns a `Stream` instance for a given stream name.""" + ... + + async def stream_name_by_subject( + self, subject: str, timeout: Optional[float] = None + ) -> str: + """Returns a stream name listening on a given subject.""" + ... + + async def delete_stream(self, name: str, timeout: Optional[float] = None) -> None: + """Removes a stream with given name.""" + ... + + def list_streams(self, timeout: Optional[float] = None) -> StreamInfoLister: + """Returns a `StreamLister` for iterating over stream infos.""" + ... + + def stream_names(self, timeout: Optional[float] = None) -> StreamNameLister: + """Returns a `StreamNameLister` for iterating over stream names.""" + ... + + +def _validate_stream_name(stream_name: str) -> None: + if not stream_name: + raise StreamNameRequiredError() + + invalid_chars = ">*. /\\" + if any(char in stream_name for char in invalid_chars): + raise InvalidStreamNameError(stream_name) diff --git a/tests/test_jetstream.py b/tests/test_jetstream.py new file mode 100644 index 00000000..19de95a0 --- /dev/null +++ b/tests/test_jetstream.py @@ -0,0 +1,805 @@ +import unittest +import asyncio +import time +import nats +import nats.jetstream + +from nats.jetstream.stream import * +from nats.jetstream.consumer import * + +from .utils import IsolatedJetStreamServerTestCase + +class TestJetStream(IsolatedJetStreamServerTestCase): + # Stream Creation Tests + async def test_create_stream_ok(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["FOO.123"]) + created_stream = await jetstream_context.create_stream(stream_config) + self.assertIsNotNone(created_stream) + self.assertEqual(created_stream.cached_info.config.name, "foo") + + await nats_client.close() + + async def test_create_stream_with_metadata(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + metadata = {"foo": "bar", "name": "test"} + stream_config = StreamConfig(name="foo_meta", subjects=["FOO.meta"], metadata=metadata) + created_stream = await jetstream_context.create_stream(stream_config) + self.assertEqual(created_stream.cached_info.config.metadata, metadata) + + await nats_client.close() + + async def test_create_stream_with_metadata_reserved_prefix(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + metadata = {"foo": "bar", "_nats_version": "2.10.0"} + stream_config = StreamConfig(name="foo_meta1", subjects=["FOO.meta1"], metadata=metadata) + created_stream = await jetstream_context.create_stream(stream_config) + self.assertEqual(created_stream.cached_info.config.metadata, metadata) + + await nats_client.close() + + async def test_create_stream_with_empty_context(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo_empty_ctx", subjects=["FOO.ctx"]) + created_stream = await jetstream_context.create_stream(stream_config) + self.assertIsNotNone(created_stream) + + await nats_client.close() + + async def test_create_stream_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(InvalidStreamNameError): + invalid_stream_config = StreamConfig(name="foo.123", subjects=["FOO.123"]) + await jetstream_context.create_stream(invalid_stream_config) + + await nats_client.close() + + async def test_create_stream_name_required(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNameRequiredError): + invalid_stream_config = StreamConfig(name="", subjects=["FOO.123"]) + await jetstream_context.create_stream(invalid_stream_config) + + await nats_client.close() + + async def test_create_stream_name_already_in_use(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["FOO.123"]) + created_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(StreamNameAlreadyInUseError): + await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["BAR.123"])) + + await nats_client.close() + + async def test_create_stream_timeout(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["BAR.123"]) + with self.assertRaises(asyncio.TimeoutError): + await jetstream_context.create_stream(stream_config, timeout=0.00001) + + await nats_client.close() + + # Create or Update Stream Tests + async def test_create_or_update_stream_create_ok(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["FOO.1"]) + created_stream = await jetstream_context.create_stream(stream_config) + self.assertIsNotNone(created_stream) + + await nats_client.close() + + + async def test_create_or_update_stream_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(InvalidStreamNameError): + invalid_stream_config = StreamConfig(name="foo.123", subjects=["FOO-123"]) + await jetstream_context.create_stream(invalid_stream_config) + + await nats_client.close() + + async def test_create_or_update_stream_name_required(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNameRequiredError): + invalid_stream_config = StreamConfig(name="", subjects=["FOO-1234"]) + await jetstream_context.create_stream(invalid_stream_config) + + await nats_client.close() + + async def test_create_or_update_stream_update_ok(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + original_config = StreamConfig(name="foo", subjects=["FOO.1"]) + await jetstream_context.create_stream(original_config) + updated_config = StreamConfig(name="foo", subjects=["BAR-123"]) + updated_stream = await jetstream_context.update_stream(updated_config) + self.assertEqual(updated_stream.cached_info.config.subjects, ["BAR-123"]) + + await nats_client.close() + + async def test_create_or_update_stream_timeout(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["BAR-1234"]) + with self.assertRaises(asyncio.TimeoutError): + await jetstream_context.create_stream(stream_config, timeout=0.000000001) + + await nats_client.close() + + # Update Stream Tests + async def test_update_stream_existing(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + original_config = StreamConfig(name="foo", subjects=["FOO.123"]) + await jetstream_context.create_stream(original_config) + updated_config = StreamConfig(name="foo", subjects=["BAR.123"]) + updated_stream = await jetstream_context.update_stream(updated_config) + self.assertEqual(updated_stream.cached_info.config.subjects, ["BAR.123"]) + + await nats_client.close() + + async def test_update_stream_add_metadata(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + original_config = StreamConfig(name="foo", subjects=["FOO.123"]) + await jetstream_context.create_stream(original_config) + metadata = {"foo": "bar", "name": "test"} + updated_config = StreamConfig(name="foo", subjects=["BAR.123"], metadata=metadata) + updated_stream = await jetstream_context.update_stream(updated_config) + self.assertEqual(updated_stream.cached_info.config.metadata, metadata) + + await nats_client.close() + + async def test_update_stream_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(InvalidStreamNameError): + invalid_config = StreamConfig(name="foo.123", subjects=["FOO.123"]) + await jetstream_context.update_stream(invalid_config) + + await nats_client.close() + + async def test_update_stream_name_required(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNameRequiredError): + invalid_config = StreamConfig(name="", subjects=["FOO.123"]) + await jetstream_context.update_stream(invalid_config) + + await nats_client.close() + + async def test_update_stream_not_found(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + nonexistent_config = StreamConfig(name="bar", subjects=["FOO.123"]) + with self.assertRaises(StreamNotFoundError): + await jetstream_context.update_stream(nonexistent_config) + + await nats_client.close() + + # Get Stream Tests + async def test_get_stream_existing(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["FOO.123"]) + await jetstream_context.create_stream(stream_config) + existing_stream = await jetstream_context.stream("foo") + self.assertIsNotNone(existing_stream) + self.assertEqual(existing_stream.cached_info.config.name, "foo") + + await nats_client.close() + + async def test_get_stream_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(InvalidStreamNameError): + await jetstream_context.stream("foo.123") + + await nats_client.close() + + async def test_get_stream_name_required(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNameRequiredError): + await jetstream_context.stream("") + + await nats_client.close() + + async def test_get_stream_not_found(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNotFoundError): + await jetstream_context.stream("bar") + + await nats_client.close() + + # Delete Stream Tests + async def test_delete_stream_existing(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="foo", subjects=["FOO.123"]) + await jetstream_context.create_stream(stream_config) + await jetstream_context.delete_stream("foo") + with self.assertRaises(StreamNotFoundError): + await jetstream_context.stream("foo") + + await nats_client.close() + + async def test_delete_stream_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(InvalidStreamNameError): + await jetstream_context.delete_stream("foo.123") + + await nats_client.close() + + async def test_delete_stream_name_required(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNameRequiredError): + await jetstream_context.delete_stream("") + + await nats_client.close() + + async def test_delete_stream_not_found(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + with self.assertRaises(StreamNotFoundError): + await jetstream_context.delete_stream("foo") + + await nats_client.close() + + # # List Streams Tests + # async def test_list_streams(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(500): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # streams = [stream async for stream in jetstream_manager.streams()] + # self.assertEqual(len(streams), 500) + + # await nats_client.close() + + # async def test_list_streams_with_subject_filter(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(260): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # streams = [stream async for stream in jetstream_context.streams(subject="FOO.123")] + # self.assertEqual(len(streams), 1) + + # await nats_client.close() + + # async def test_list_streams_with_subject_filter_no_match(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(100): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # streams = [stream async for stream in jetstream_manager.streams(subject="FOO.500")] + # self.assertEqual(len(streams), 0) + + # await nats_client.close() + + # async def test_list_streams_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # streams = [stream async for stream in jetstream_manager.streams()] + + # await nats_client.close() + + # # Stream Names Tests + # async def test_stream_names(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(500): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # names = [name async for name in jetstream_manager.stream_names()] + # self.assertEqual(len(names), 500) + + # await nats_client.close() + + # async def test_stream_names_with_subject_filter(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(260): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # names = [name async for name in jetstream_manager.stream_names(subject="FOO.123")] + # self.assertEqual(len(names), 1) + + # await nats_client.close() + + # async def test_stream_names_with_subject_filter_no_match(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # for i in range(100): + # await jetstream_context.create_stream(StreamConfig(name=f"foo{i}", subjects=[f"FOO.{i}"])) + # names = [name async for name in jetstream_manager.stream_names(subject="FOO.500")] + # self.assertEqual(len(names), 0) + + # await nats_client.close() + + # async def test_stream_names_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # names = [name async for name in jetstream_manager.stream_names()] + + # await nats_client.close() + + # # Stream by Subject Tests + # async def test_stream_name_by_subject_explicit(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # stream_name = await jetstream_manager.stream_name_by_subject("FOO.123") + # self.assertEqual(stream_name, "foo") + + # await nats_client.close() + + # async def test_stream_name_by_subject_wildcard(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="bar", subjects=["BAR.*"])) + # stream_name = await jetstream_manager.stream_name_by_subject("BAR.*") + # self.assertEqual(stream_name, "bar") + + # await nats_client.close() + + # async def test_stream_name_by_subject_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_manager.stream_name_by_subject("BAR.XYZ") + + # await nats_client.close() + + # async def test_stream_name_by_subject_invalid(self): + # nats_client = await nats.connect("nats://localhost:4222") + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_manager.stream_name_by_subject("FOO.>.123") + + # await nats_client.close() + + # async def test_stream_name_by_subject_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # await jetstream_manager.stream_name_by_subject("FOO.123") + + # await nats_client.close() + + # # Consumer Tests + # async def test_create_or_update_consumer_create_durable_pull(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # consumer = await jetstream_context.create_consumer("foo", consumer_config) + # self.assertIsNotNone(consumer) + # self.assertEqual(consumer.name, "dur") + + # await nats_client.close() + + # async def test_create_or_update_consumer_create_ephemeral_pull(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(ack_policy=ConsumerConfig.AckExplicit) + # consumer = await jetstream_context.create_consumer("foo", consumer_config) + # self.assertIsNotNone(consumer) + + # await nats_client.close() + + # async def test_create_or_update_consumer_update(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # await jetstream_context.create_consumer("foo", consumer_config) + # updated_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit, description="test consumer") + # updated_consumer = await jetstream_context.update_consumer("foo", updated_config) + # self.assertEqual(updated_consumer.config.description, "test consumer") + + # await nats_client.close() + + # async def test_create_or_update_consumer_illegal_update(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # await jetstream_context.create_consumer("foo", consumer_config) + # illegal_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckNone) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.update_consumer("foo", illegal_config) + + # await nats_client.close() + + # async def test_create_or_update_consumer_stream_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.create_consumer("nonexistent", consumer_config) + + # await nats_client.close() + + # async def test_create_or_update_consumer_invalid_stream_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.create_consumer("foo.1", consumer_config) + + # await nats_client.close() + + # async def test_create_or_update_consumer_invalid_durable_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur.123", ack_policy=ConsumerConfig.AckExplicit) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.create_consumer("foo", consumer_config) + + # await nats_client.close() + + # async def test_create_or_update_consumer_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # await jetstream_context.create_consumer("foo", consumer_config) + + # await nats_client.close() + + # # Get Consumer Tests + # async def test_get_consumer_existing(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # await jetstream_context.create_consumer("foo", consumer_config) + # consumer = await jetstream_context.consumer_info("foo", "dur") + # self.assertIsNotNone(consumer) + # self.assertEqual(consumer.name, "dur") + + # await nats_client.close() + + # async def test_get_consumer_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.consumer_info("foo", "nonexistent") + + # await nats_client.close() + + # async def test_get_consumer_invalid_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.consumer_info("foo", "dur.123") + + # await nats_client.close() + + # async def test_get_consumer_stream_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.consumer_info("nonexistent", "dur") + + # await nats_client.close() + + # async def test_get_consumer_invalid_stream_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.consumer_info("foo.1", "dur") + + # await nats_client.close() + + # async def test_get_consumer_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # await jetstream_context.create_consumer("foo", consumer_config) + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # await jetstream_context.consumer_info("foo", "dur") + + # await nats_client.close() + + # # Delete Consumer Tests + # async def test_delete_consumer_existing(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.ACK_EXPLICIT) + # await jetstream_context.create_consumer("foo", consumer_config) + # await jetstream_context.delete_consumer("foo", "dur") + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.consumer_info("foo", "dur") + + # await nats_client.close() + + # async def test_delete_consumer_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.delete_consumer("foo", "nonexistent") + + # await nats_client.close() + + # async def test_delete_consumer_invalid_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.delete_consumer("foo", "dur.123") + + # await nats_client.close() + + # async def test_delete_consumer_stream_not_found(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.delete_consumer("nonexistent", "dur") + + # await nats_client.close() + + # async def test_delete_consumer_invalid_stream_name(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # with self.assertRaises(Exception): # Replace with specific exception + # await jetstream_context.delete_consumer("foo.1", "dur") + + # await nats_client.close() + + # async def test_delete_consumer_timeout(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # consumer_config = ConsumerConfig(durable_name="dur", ack_policy=ConsumerConfig.AckExplicit) + # await jetstream_context.create_consumer("foo", consumer_config) + # with self.assertRaises(asyncio.TimeoutError): + # async with asyncio.timeout(timeout=0.000001): + # await jetstream_context.delete_consumer("foo", "dur") + + # await nats_client.close() + + # # JetStream Account Info Tests + # async def test_account_info(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + # jetstream_manager = JetStreamManager(nats_client) + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # info = await jetstream_manager.account_info() + # self.assertIsNotNone(info) + # self.assertGreaterEqual(info.streams, 1) + + # await nats_client.close() + + # Stream Config Tests + async def test_stream_config_matches(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig( + name="stream", + subjects=["foo.*"], + retention=RetentionPolicy.LIMITS, + max_consumers=-1, + max_msgs=-1, + max_bytes=-1, + discard=DiscardPolicy.OLD, + max_age=0, + max_msgs_per_subject=-1, + max_msg_size=-1, + storage=StorageType.FILE, + replicas=1, + no_ack=False, + discard_new_per_subject=False, + duplicates=120 * 1000000000, # 120 seconds in nanoseconds + placement=None, + mirror=None, + sources=None, + sealed=False, + deny_delete=False, + deny_purge=False, + allow_rollup=False, + compression=StoreCompression.NONE, + first_sequence=0, + subject_transform=None, + republish=None, + allow_direct=False, + mirror_direct=False, + ) + + stream = await jetstream_context.create_stream(stream_config) + self.assertEqual(stream.cached_info.config, stream_config) + + await nats_client.close() + + # # Consumer Config Tests + # async def test_consumer_config_matches(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="FOO", subjects=["foo.*"])) + # config = ConsumerConfig( + # durable_name="cons", + # description="test", + # deliver_policy=ConsumerConfig.DeliverAll, + # opt_start_seq=0, + # opt_start_time=None, + # ack_policy=ConsumerConfig.AckExplicit, + # ack_wait=30 * 1000000000, # 30 seconds in nanoseconds + # max_deliver=1, + # filter_subject="", + # replay_policy=ConsumerConfig.ReplayInstant, + # rate_limit_bps=0, + # sample_freq="", + # max_waiting=1, + # max_ack_pending=1000, + # headers_only=False, + # max_batch=0, + # max_expires=0, + # inactive_threshold=0, + # num_replicas=1, + # ) + # consumer = await jetstream_context.create_consumer("FOO", config) + # self.assertEqual(consumer.config, config) + + # await nats_client.close() + + # JetStream Publish Tests + async def test_publish(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + ack = await jetstream_context.publish("FOO.bar", b"Hello World") + self.assertIsNotNone(ack) + self.assertGreater(ack.sequence, 0) + + await nats_client.close() + + # # JetStream Subscribe Tests + # async def test_subscribe_push(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # sub = await jetstream_context.subscribe("FOO.*") + # await jetstream_context.publish("FOO.bar", b"Hello World") + # msg = await sub.next_msg() + # self.assertEqual(msg.data, b"Hello World") + + # await nats_client.close() + + # async def test_subscribe_pull(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # await jetstream_context.create_stream(StreamConfig(name="foo", subjects=["FOO.*"])) + # sub = await jetstream_context.pull_subscribe("FOO.*", "consumer") + # await jetstream_context.publish("FOO.bar", b"Hello World") + # msgs = await sub.fetch(1) + # self.assertEqual(len(msgs), 1) + # self.assertEqual(msgs[0].data, b"Hello World") + + # await nats_client.close() + + # # JetStream Stream Transform Tests + # async def test_stream_transform(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats_client.jetstream() + + # origin_config = StreamConfig( + # name="ORIGIN", + # subjects=["test"], + # storage=StorageType.MEMORY, + # subject_transform=SubjectTransformConfig(source=">", destination="transformed.>") + # ) + # await jetstream_context.create_stream(origin_config) + + # await nats_client.publish("test", b"1") + + # sourcing_config = StreamConfig( + # name="SOURCING", + # storage=StreamConfig.MemoryStorage, + # sources=[ + # StreamSource( + # name="ORIGIN", + # subject_transforms=[ + # StreamConfig.SubjectTransform(src=">", dest="fromtest.>") + # ] + # ) + # ] + # ) + # sourcing_stream = await jetstream_context.create_stream(sourcing_config) + + # consumer_config = ConsumerConfig( + # filter_subject="fromtest.>", + # max_deliver=1, + # ) + # consumer = await sourcing_stream.create_consumer(consumer_config) + + # msg = await consumer.next_msg() + # self.assertEqual(msg.subject, "fromtest.transformed.test") + + # await nats_client.close() diff --git a/tests/test_jetstream_consumer.py b/tests/test_jetstream_consumer.py new file mode 100644 index 00000000..deedae39 --- /dev/null +++ b/tests/test_jetstream_consumer.py @@ -0,0 +1,335 @@ +import unittest +import asyncio +import nats.jetstream +import nats + +from nats.errors import TimeoutError +from tests.utils import IsolatedJetStreamServerTestCase + +class TestPullConsumerFetch(IsolatedJetStreamServerTestCase): + async def publish_test_msgs(self, js): + for msg in self.test_msgs: + await js.publish(self.test_subject, msg.encode()) + + async def test_fetch_no_options(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + msgs = await consumer.fetch(5) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(len(self.test_msgs), len(received_msgs)) + for i, msg in enumerate(received_msgs): + self.assertEqual(self.test_msgs[i], msg.data.decode()) + + self.assertIsNone(msgs.error()) + + async def test_delete_consumer_during_fetch(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + msgs = await consumer.fetch(10) + await asyncio.sleep(0.1) + await stream.delete_consumer(consumer.name) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(len(self.test_msgs), len(received_msgs)) + for i, msg in enumerate(received_msgs): + self.assertEqual(self.test_msgs[i], msg.data.decode()) + + self.assertIsInstance(msgs.error(), ConsumerDeletedError) + + async def test_fetch_single_messages_one_by_one(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + received_msgs = [] + + async def fetch_messages(): + while len(received_msgs) < len(self.test_msgs): + msgs = await consumer.fetch(1) + async for msg in msgs.messages(): + if msg: + received_msgs.append(msg) + if msgs.error(): + return + + task = asyncio.create_task(fetch_messages()) + await asyncio.sleep(0.01) + await self.publish_test_msgs(js) + await task + + self.assertEqual(len(self.test_msgs), len(received_msgs)) + for i, msg in enumerate(received_msgs): + self.assertEqual(self.test_msgs[i], msg.data.decode()) + + async def test_fetch_no_wait_no_messages(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + msgs = await consumer.fetch_no_wait(5) + await asyncio.sleep(0.1) + await self.publish_test_msgs(js) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + + async def test_fetch_no_wait_some_messages_available(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + await asyncio.sleep(0.05) + msgs = await consumer.fetch_no_wait(10) + await asyncio.sleep(0.1) + await self.publish_test_msgs(js) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(len(self.test_msgs), len(received_msgs)) + + async def test_fetch_with_timeout(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + msgs = await consumer.fetch(5, max_wait=0.05) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + + async def test_fetch_with_invalid_timeout(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + with self.assertRaises(ValueError): + await consumer.fetch(5, max_wait=-0.05) + + async def test_fetch_with_missing_heartbeat(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + msgs = await consumer.fetch(5, heartbeat=0.05) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(len(self.test_msgs), len(received_msgs)) + self.assertIsNone(msgs.error()) + + msgs = await consumer.fetch(5, heartbeat=0.05, max_wait=0.2) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + self.assertIsNone(msgs.error()) + + await stream.delete_consumer(consumer.name) + msgs = await consumer.fetch(5, heartbeat=0.05) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + self.assertIsInstance(msgs.error(), TimeoutError) + + async def test_fetch_with_invalid_heartbeat(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + with self.assertRaises(ValueError): + await consumer.fetch(5, heartbeat=20) + + with self.assertRaises(ValueError): + await consumer.fetch(5, heartbeat=2, max_wait=3) + + with self.assertRaises(ValueError): + await consumer.fetch(5, heartbeat=-2) + +class TestPullConsumerFetchBytes(IsolatedJetStreamTest): + async def setUp(self): + await super().setUp() + self.test_subject = "FOO.123" + self.msg = b"0123456789" + + async def publish_test_msgs(self, js, count): + for _ in range(count): + await js.publish(self.test_subject, self.msg) + + async def test_fetch_bytes_exact_count(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT, name="con") + + await self.publish_test_msgs(js, 5) + msgs = await consumer.fetch_bytes(300) + + received_msgs = [] + async for msg in msgs.messages(): + await msg.ack() + received_msgs.append(msg) + + self.assertEqual(5, len(received_msgs)) + self.assertIsNone(msgs.error()) + + async def test_fetch_bytes_last_msg_does_not_fit(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT, name="con") + + await self.publish_test_msgs(js, 5) + msgs = await consumer.fetch_bytes(250) + + received_msgs = [] + async for msg in msgs.messages(): + await msg.ack() + received_msgs.append(msg) + + self.assertEqual(4, len(received_msgs)) + self.assertIsNone(msgs.error()) + + async def test_fetch_bytes_single_msg_too_large(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT, name="con") + + await self.publish_test_msgs(js, 5) + msgs = await consumer.fetch_bytes(30) + + received_msgs = [] + async for msg in msgs.messages(): + await msg.ack() + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + self.assertIsNone(msgs.error()) + + async def test_fetch_bytes_timeout(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT, name="con") + + await self.publish_test_msgs(js, 5) + msgs = await consumer.fetch_bytes(1000, max_wait=0.05) + + received_msgs = [] + async for msg in msgs.messages(): + await msg.ack() + received_msgs.append(msg) + + self.assertEqual(5, len(received_msgs)) + self.assertIsNone(msgs.error()) + + async def test_fetch_bytes_missing_heartbeat(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + msgs = await consumer.fetch_bytes(5, heartbeat=0.05, max_wait=0.2) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + self.assertIsNone(msgs.error()) + + await stream.delete_consumer(consumer.name) + msgs = await consumer.fetch_bytes(5, heartbeat=0.05) + + received_msgs = [] + async for msg in msgs.messages(): + received_msgs.append(msg) + + self.assertEqual(0, len(received_msgs)) + self.assertIsInstance(msgs.error(), TimeoutError) + + async def test_fetch_bytes_invalid_heartbeat(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + with self.assertRaises(ValueError): + await consumer.fetch_bytes(5, heartbeat=20) + + with self.assertRaises(ValueError): + await consumer.fetch_bytes(5, heartbeat=2, max_wait=3) + + with self.assertRaises(ValueError): + await consumer.fetch_bytes(5, heartbeat=-2) + +class TestPullConsumerMessages(IsolatedJetStreamTest): + async def setUp(self): + await super().setUp() + self.test_subject = "FOO.123" + self.test_msgs = ["m1", "m2", "m3", "m4", "m5"] + + async def publish_test_msgs(self, js): + for msg in self.test_msgs: + await js.publish(self.test_subject, msg.encode()) + + async def test_messages_no_options(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + msgs = [] + async with consumer.messages() as iterator: + async for msg in iterator: + msgs.append(msg) + if len(msgs) == len(self.test_msgs): + break + + self.assertEqual(len(self.test_msgs), len(msgs)) + for i, msg in enumerate(msgs): + self.assertEqual(self.test_msgs[i], msg.data.decode()) + + async def test_messages_delete_consumer_during_iteration(self): + js = await jetstream.JetStream.connect(self.nc) + stream = await js.add_stream(name="foo", subjects=["FOO.*"]) + consumer = await stream.create_consumer(jetstream.PullConsumer, ack_policy=jetstream.AckPolicy.EXPLICIT) + + await self.publish_test_msgs(js) + msgs = [] + async with consumer.messages() as iterator: + async for msg in iterator: + msgs.append(msg) + if len(msgs) == len(self.test_msgs): + break + + await stream.delete_consumer(consumer.name) + + with self.assertRaises(ConsumerDeletedError): + async with consumer.messages() as iterator: + async for _ in iterator: + pass diff --git a/tests/test_jetstream_stream.py b/tests/test_jetstream_stream.py new file mode 100644 index 00000000..14e783ad --- /dev/null +++ b/tests/test_jetstream_stream.py @@ -0,0 +1,609 @@ +import asyncio +import nats +import nats.jetstream + +from nats.jetstream.stream import StreamConfig +from tests.utils import IsolatedJetStreamServerTestCase +from nats.jetstream.consumer import AckPolicy, ConsumerConfig, PullConsumer + +class TestJetStreamStream(IsolatedJetStreamServerTestCase): + # CreateConsumer tests + async def test_create_durable_pull_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(durable="durable_consumer") + created_consumer = await test_stream.create_consumer(consumer_config) + self.assertIsNotNone(created_consumer) + self.assertEqual(created_consumer.cached_info.name, "durable_consumer") + + await nats_client.close() + + async def test_create_consumer_idempotent(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(durable="durable_consumer") + first_consumer = await test_stream.create_consumer(consumer_config) + second_consumer = await test_stream.create_consumer(consumer_config) + self.assertEqual(first_consumer.cached_info.name, second_consumer.cached_info.name) + + await nats_client.close() + + async def test_create_ephemeral_pull_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(ack_policy=AckPolicy.NONE) + created_consumer = await test_stream.create_consumer(consumer_config) + self.assertIsNotNone(created_consumer) + + await nats_client.close() + + async def test_create_consumer_with_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subject="TEST.A") + created_consumer = await test_stream.create_consumer(consumer_config) + self.assertIsNotNone(created_consumer) + self.assertEqual(created_consumer.cached_info.config.filter_subject, "TEST.A") + + await nats_client.close() + + async def test_create_consumer_with_metadata(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(metadata={"foo": "bar", "baz": "quux"}) + created_consumer = await test_stream.create_consumer(consumer_config) + self.assertEqual(created_consumer.cached_info.config.metadata, {"foo": "bar", "baz": "quux"}) + + await nats_client.close() + + async def test_create_consumer_with_multiple_filter_subjects(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", "TEST.B"]) + created_consumer = await test_stream.create_consumer(consumer_config) + self.assertIsNotNone(created_consumer) + self.assertEqual(created_consumer.cached_info.config.filter_subjects, ["TEST.A", "TEST.B"]) + + await nats_client.close() + + async def test_create_consumer_with_overlapping_filter_subjects(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.*", "TEST.B"]) + with self.assertRaises(Exception): + await test_stream.create_consumer(consumer_config) + + await nats_client.close() + + async def test_create_consumer_with_filter_subjects_and_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", "TEST.B"], filter_subject="TEST.C") + with self.assertRaises(Exception): + await test_stream.create_consumer(consumer_config) + + await nats_client.close() + + async def test_create_consumer_with_empty_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", ""]) + with self.assertRaises(Exception): + await test_stream.create_consumer(consumer_config) + + await nats_client.close() + + async def test_create_consumer_with_invalid_filter_subject_leading_dot(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subject=".TEST") + with self.assertRaises(Exception): + await test_stream.create_consumer(consumer_config) + + await nats_client.close() + + async def test_create_consumer_with_invalid_filter_subject_trailing_dot(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subject="TEST.") + with self.assertRaises(Exception): + await test_stream.create_consumer(consumer_config) + + await nats_client.close() + + async def test_create_consumer_already_exists(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + await test_stream.create_consumer(ConsumerConfig(durable="durable_consumer")) + + with self.assertRaises(Exception): + await test_stream.create_consumer(ConsumerConfig(durable="durable_consumer", description="test consumer")) + + await nats_client.close() + + # UpdateConsumer tests + async def test_update_consumer_with_existing_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + original_config = ConsumerConfig(name="test_consumer", description="original description") + await test_stream.create_consumer(original_config) + updated_config = ConsumerConfig(name="test_consumer", description="updated description") + updated_consumer = await test_stream.update_consumer(updated_config) + self.assertEqual(updated_consumer.cached_info.config.description, "updated description") + + await nats_client.close() + + async def test_update_consumer_with_metadata(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + original_config = ConsumerConfig(name="test_consumer") + await test_stream.create_consumer(original_config) + updated_config = ConsumerConfig(name="test_consumer", metadata={"foo": "bar", "baz": "quux"}) + updated_consumer = await test_stream.update_consumer(updated_config) + self.assertEqual(updated_consumer.cached_info.config.metadata, {"foo": "bar", "baz": "quux"}) + + await nats_client.close() + + async def test_update_consumer_illegal_consumer_update(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + original_config = ConsumerConfig(name="test_consumer", ack_policy=AckPolicy.EXPLICIT) + await test_stream.create_consumer(original_config) + illegal_config = ConsumerConfig(name="test_consumer", ack_policy=AckPolicy.NONE) + with self.assertRaises(Exception): + await test_stream.update_consumer(illegal_config) + + await nats_client.close() + + async def test_update_non_existent_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + non_existent_config = ConsumerConfig(name="non_existent_consumer") + with self.assertRaises(Exception): + await test_stream.update_consumer(non_existent_config) + + await nats_client.close() + + # Consumer tests + async def test_get_existing_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(durable="durable_consumer") + await test_stream.create_consumer(consumer_config) + retrieved_consumer = await test_stream.consumer("durable_consumer") + self.assertIsNotNone(retrieved_consumer) + self.assertEqual(retrieved_consumer.cached_info.name, "durable_consumer") + + await nats_client.close() + + async def test_get_non_existent_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(Exception): + await test_stream.consumer("non_existent_consumer") + + await nats_client.close() + + async def test_get_consumer_with_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(Exception): + await test_stream.consumer("invalid.consumer.name") + + await nats_client.close() + + # CreateOrUpdateConsumer tests + async def test_create_or_update_durable_pull_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + + consumer_config = ConsumerConfig(durable="durable_consumer") + created_consumer = await test_stream.create_or_update_consumer(consumer_config) + self.assertIsInstance(created_consumer, PullConsumer) + self.assertEqual(created_consumer.cached_info.name, "durable_consumer") + + await nats_client.close() + + async def test_create_or_update_ephemeral_pull_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(ack_policy=AckPolicy.NONE) + created_consumer = await test_stream.create_or_update_consumer(consumer_config) + self.assertIsInstance(created_consumer, PullConsumer) + + await nats_client.close() + + async def test_create_or_update_consumer_with_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subject="TEST.A") + created_consumer = await test_stream.create_or_update_consumer(consumer_config) + self.assertIsInstance(created_consumer, PullConsumer) + self.assertEqual(created_consumer.cached_info.config.filter_subject, "TEST.A") + + await nats_client.close() + + async def test_create_or_update_consumer_with_multiple_filter_subjects(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", "TEST.B"]) + created_consumer = await test_stream.create_or_update_consumer(consumer_config) + self.assertIsInstance(created_consumer, PullConsumer) + self.assertEqual(created_consumer.cached_info.config.filter_subjects, ["TEST.A", "TEST.B"]) + + await nats_client.close() + + async def test_create_or_update_consumer_with_overlapping_filter_subjects(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config=StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.*", "TEST.B"]) + with self.assertRaises(Exception): + await test_stream.create_or_update_consumer(consumer_config) + + await nats_client.close() + + async def test_create_or_update_consumer_with_filter_subjects_and_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", "TEST.B"], filter_subject="TEST.C") + with self.assertRaises(Exception): + await test_stream.create_or_update_consumer(consumer_config) + + await nats_client.close() + + async def test_create_or_update_consumer_with_empty_filter_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(filter_subjects=["TEST.A", ""]) + with self.assertRaises(Exception): + await test_stream.create_or_update_consumer(consumer_config) + + await nats_client.close() + + async def test_create_or_update_existing_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + + consumer_config = ConsumerConfig(durable="durable_consumer") + await test_stream.create_or_update_consumer(consumer_config) + + updated_config = ConsumerConfig(durable="durable_consumer", description="test consumer") + updated_consumer = await test_stream.create_or_update_consumer(updated_config) + self.assertEqual(updated_consumer.cached_info.config.description, "test consumer") + + await nats_client.close() + + async def test_create_or_with_update_illegal_update_of_existing_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + + original_config = ConsumerConfig(durable="durable_consumer_2", ack_policy=AckPolicy.EXPLICIT) + await test_stream.create_or_update_consumer(original_config) + + updated_config = ConsumerConfig(durable="durable_consumer_2", ack_policy=AckPolicy.NONE) + with self.assertRaises(Exception): + await test_stream.create_or_update_consumer(updated_config) + + await nats_client.close() + + async def test_create_or_update_consumer_with_invalid_durable(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + invalid_config = ConsumerConfig(durable="invalid.durable.name") + with self.assertRaises(Exception): + await test_stream.create_or_update_consumer(invalid_config) + + await nats_client.close() + + # DeleteConsumer tests + async def test_delete_existing_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + consumer_config = ConsumerConfig(durable="durable_consumer") + await test_stream.create_consumer(consumer_config) + await test_stream.delete_consumer("durable_consumer") + with self.assertRaises(Exception): + await test_stream.consumer("durable_consumer") + + await nats_client.close() + + async def test_delete_non_existent_consumer(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(Exception): + await test_stream.delete_consumer("non_existent_consumer") + + await nats_client.close() + + async def test_delete_consumer_with_invalid_name(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(Exception): + await test_stream.delete_consumer("invalid.consumer.name") + + await nats_client.close() + + # StreamInfo tests + async def test_stream_info_without_options(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + stream_info = await test_stream.info() + self.assertIsNotNone(stream_info) + self.assertEqual(stream_info.config.name, "test_stream") + + await nats_client.close() + + async def test_stream_info_with_deleted_details(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + for i in range(10): + await jetstream_context.publish("TEST.A", f"msg {i}".encode()) + await test_stream.delete_message(3) + await test_stream.delete_message(5) + stream_info = await test_stream.info(deleted_details=True) + self.assertEqual(stream_info.state.num_deleted, 2) + self.assertEqual(stream_info.state.deleted, [3, 5]) + + await nats_client.close() + + async def test_stream_info_with_subject_filter(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + for i in range(10): + await jetstream_context.publish("TEST.A", f"msg A {i}".encode()) + await jetstream_context.publish("TEST.B", f"msg B {i}".encode()) + stream_info = await test_stream.info(subject_filter="TEST.A") + self.assertEqual(stream_info.state.subjects.get("TEST.A"), 10) + self.assertNotIn("TEST.B", stream_info.state.subjects) + + await nats_client.close() + + async def test_stream_info_timeout(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(asyncio.TimeoutError): + await test_stream.info(timeout=0.00001) + + await nats_client.close() + + # # SubjectsFilterPaging test + # async def test_subjects_filter_paging(self): + # nats_client = await nats.connect("nats://localhost:4222") + # jetstream_context = nats.jetstream.new(nats_client) + + # test_stream = await jetstream_context.create_stream(name="test_stream", subjects=["TEST.*"]) + # for i in range(110000): + # await jetstream_context.publish(f"TEST.{i}", b"data") + + # stream_info = await test_stream.info(subject_filter="TEST.*") + # self.assertEqual(len(stream_info.state.subjects), 110000) + + # cached_info = test_stream.cached_info() + # self.assertEqual(len(cached_info.state.subjects), 0) + + # await nats_client.close() + + # # StreamCachedInfo test + async def test_stream_cached_info(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"], description="original") + test_stream = await jetstream_context.create_stream(stream_config) + self.assertEqual(test_stream.cached_info.config.name, "test_stream") + self.assertEqual(test_stream.cached_info.config.description, "original") + + updated_stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"], description="updated") + await jetstream_context.update_stream(updated_stream_config) + self.assertEqual(test_stream.cached_info.config.description, "original") + + updated_info = await test_stream.info() + self.assertEqual(updated_info.config.description, "updated") + + await nats_client.close() + self.assertEqual(test_stream.cached_info.config.description, "original") + + updated_info = await test_stream.info() + self.assertEqual(updated_info.config.description, "updated") + + await nats_client.close() + + # # GetMsg tests + async def test_get_existing_message(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + await jetstream_context.publish("TEST.A", b"test message") + msg = await test_stream.get_msg(1) + self.assertEqual(msg.data, b"test message") + + await nats_client.close() + + async def test_get_non_existent_message(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + with self.assertRaises(nats.errors.Error): + await test_stream.get_msg(100) + + await nats_client.close() + + async def test_get_deleted_message(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + await jetstream_context.publish("TEST.A", b"test message") + await test_stream.delete_message(1) + with self.assertRaises(nats.errors.Error): + await test_stream.get_msg(1) + + await nats_client.close() + + async def test_get_message_with_headers(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + headers = {"X-Test": "test value"} + await jetstream_context.publish("TEST.A", b"test message", headers=headers) + msg = await test_stream.get_msg(1) + self.assertEqual(msg.data, b"test message") + self.assertEqual(msg.headers.get("X-Test"), "test value") + + await nats_client.close() + + async def test_get_message_context_timeout(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + await jetstream_context.publish("TEST.A", b"test message") + with self.assertRaises(asyncio.TimeoutError): + async with asyncio.timeout(0.001): + await test_stream.get_msg(1) + + await nats_client.close() + + # GetLastMsgForSubject tests + async def test_get_last_message_for_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + for i in range(5): + await jetstream_context.publish("TEST.A", f"msg A {i}".encode()) + await jetstream_context.publish("TEST.B", f"msg B {i}".encode()) + msg = await test_stream.get_last_msg_for_subject("TEST.A") + self.assertEqual(msg.data, b"msg A 4") + + await nats_client.close() + + async def test_get_last_message_for_wildcard_subject(self): + nats_client = await nats.connect("nats://localhost:4222") + jetstream_context = nats.jetstream.new(nats_client) + + stream_config = StreamConfig(name="test_stream", subjects=["TEST.*"]) + test_stream = await jetstream_context.create_stream(stream_config) + for i in range(5): + await jetstream_context.publish(f"TEST.{i}", b"data") + msg = await test_stream.get_last_msg_for_subject("TEST.*") + self.assertEqual(msg.data, b"data") + + await nats_client.close() diff --git a/tests/utils.py b/tests/utils.py index 43dd27fa..7af4d0ff 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -470,6 +470,21 @@ def tearDown(self): shutil.rmtree(natsd.store_dir) self.loop.close() +class IsolatedJetStreamServerTestCase(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.server_pool = [] + server = NATSD(port=4222, with_jetstream=True) + self.server_pool.append(server) + for natsd in self.server_pool: + start_natsd(natsd) + + def tearDown(self): + for natsd in self.server_pool: + natsd.stop() + shutil.rmtree(natsd.store_dir) + + self._server_pool = None + class SingleWebSocketServerTestCase(unittest.TestCase):