From 4d7d29159f1663736539a447aba615fd623a1032 Mon Sep 17 00:00:00 2001 From: Marcos Schroh <2828842+marcosschroh@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:56:56 +0200 Subject: [PATCH] feat: Stream error policy added (#206) --- docs/middleware.md | 6 +- docs/stream.md | 63 +++++++------- kstreams/engine.py | 26 ++++-- kstreams/middleware/middleware.py | 32 ++++++-- kstreams/prometheus/monitor.py | 52 ++++++------ kstreams/rebalance_listener.py | 32 +++++--- kstreams/streams.py | 5 ++ kstreams/streams_utils.py | 6 ++ tests/conftest.py | 3 + tests/middleware/test_middleware.py | 4 +- tests/test_client.py | 12 +-- tests/test_consumer.py | 32 +++++--- tests/test_monitor.py | 2 +- tests/test_streams.py | 4 +- tests/test_streams_error_policy.py | 122 ++++++++++++++++++++++++++++ 15 files changed, 291 insertions(+), 110 deletions(-) create mode 100644 tests/test_streams_error_policy.py diff --git a/docs/middleware.md b/docs/middleware.md index 8835c3ee..fd074db4 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -95,11 +95,11 @@ middlewares = [ raise ValueError("Joker received...") ``` -## Middleware by default - -Kstreams includes one middleware by default, `ExceptionMiddleware`. This middleware adds exception handlers, for particular types of expected exception cases, for example when the `Consumer` stops (kafka disconnects), user presses `CTRL+C` or any other exception that could cause the `stream` to crash. +## Default Middleware ::: kstreams.middleware.middleware.ExceptionMiddleware + options: + show_bases: false ## Middleware chain diff --git a/docs/stream.md b/docs/stream.md index 6ee5b1ce..79ff0e42 100644 --- a/docs/stream.md +++ b/docs/stream.md @@ -156,52 +156,51 @@ stream = Stream( ## Stream crashing -If your stream `crashes` for any reason, the event consumption will stop meaning that non event will be consumed from the `topic`. -As an end user you are responsable of deciding what to do. 
In future version approaches like `re-try`, `stream engine stops on stream crash` might be introduced. +If your stream `crashes` for any reason the event consumption is stopped, meaning that no event will be consumed from the `topic`. However, it is possible to set three different `error policies` per stream: -```python title="Crashing example" -import aiorun +- `StreamErrorPolicy.STOP` (**default**): Stop the `Stream` when an exception occurs. The exception is raised after the stream is properly stopped. +- `StreamErrorPolicy.RESTART`: Stop and restart the `Stream` when an exception occurs. The event that caused the exception is skipped. The exception is *NOT raised* because the application should continue working, however `logger.exception()` is used to alert the user. +- `StreamErrorPolicy.STOP_ENGINE`: Stop the `StreamEngine` when an exception occurs. The exception is raised after *ALL* the Streams were properly stopped. + +In the following example, the `StreamErrorPolicy.RESTART` error policy is specified. If the `Stream` crashes with a `ValueError` exception, it is restarted: + +```python from kstreams import create_engine, ConsumerRecord +from kstreams.streams_utils import StreamErrorPolicy stream_engine = create_engine(title="my-stream-engine") -@stream_engine.stream("local--kstreams", group_id="de-my-partition") +@stream_engine.stream( + "local--hello-world", + group_id="example-group", + error_policy=StreamErrorPolicy.RESTART +) async def stream(cr: ConsumerRecord) -> None: - print(f"Event consumed. Payload {cr.payload}") + if cr.key == b"error": + # Stream will be restarted after the ValueError is raised + raise ValueError("error....") + print(f"Event consumed. Payload {cr.value}") +``` -async def produce(): - await stream_engine.send( - "local--kstreams", - value=b"Hi" - ) - - -async def start(): - await stream_engine.start() - await produce() - +We can see the logs: -async def shutdown(loop): - await stream_engine.stop() +```bash +ValueError: error.... 
+INFO:aiokafka.consumer.group_coordinator:LeaveGroup request succeeded +INFO:aiokafka.consumer.consumer:Unsubscribed all topics or patterns and assigned partitions +INFO:kstreams.streams:Stream consuming from topics ['local--hello-world'] has stopped!!! -if __name__ == "__main__": - aiorun.run(start(), stop_on_unhandled_errors=True, shutdown_callback=shutdown) +INFO:kstreams.middleware.middleware:Restarting stream +INFO:aiokafka.consumer.subscription_state:Updating subscribed topics to: frozenset({'local--hello-world'}) +... +INFO:aiokafka.consumer.group_coordinator:Setting newly assigned partitions {TopicPartition(topic='local--hello-world', partition=0)} for group example-group ``` -```bash -CRASHED Stream!!! Task .func_wrapper() running at /Users/Projects/kstreams/kstreams/streams.py:55>> - - 'ConsumerRecord' object has no attribute 'payload' -Traceback (most recent call last): - File "/Users/Projects/kstreams/kstreams/streams.py", line 52, in func_wrapper - await self.func(self) - File "/Users/Projects/kstreams/examples/fastapi_example/streaming/streams.py", line 9, in stream - print(f"Event consumed: headers: {cr.headers}, payload: {cr.payload}") -AttributeError: 'ConsumerRecord' object has no attribute 'payload' -``` +!!! 
note + If you are using `aiorun` with `stop_on_unhandled_errors=True` and the `error_policy` is `StreamErrorPolicy.RESTART` then the `application` will NOT stop as the exception that caused the `Stream` to `crash` is not `raised` ## Changing consumer behavior diff --git a/kstreams/engine.py b/kstreams/engine.py index 792d474b..0144f792 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -17,7 +17,7 @@ from .serializers import Deserializer, Serializer from .streams import Stream, StreamFunc from .streams import stream as stream_func -from .streams_utils import UDFType +from .streams_utils import StreamErrorPolicy, UDFType from .types import EngineHooks, Headers, NextMiddlewareCall from .utils import encode_headers, execute_hooks @@ -342,7 +342,9 @@ def get_stream(self, name: str) -> typing.Optional[Stream]: return stream - def add_stream(self, stream: Stream) -> None: + def add_stream( + self, stream: Stream, error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP + ) -> None: if self.exist_stream(stream.name): raise DuplicateStreamException(name=stream.name) stream.backend = self.backend @@ -367,12 +369,18 @@ def add_stream(self, stream: Stream) -> None: # NOTE: When `no typing` support is deprecated this check can # be removed if stream.udf_handler.type != UDFType.NO_TYPING: - stream.func = self.build_stream_middleware_stack(stream=stream) + stream.func = self.build_stream_middleware_stack( + stream=stream, error_policy=error_policy + ) - def build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall: + def build_stream_middleware_stack( + self, *, stream: Stream, error_policy: StreamErrorPolicy + ) -> NextMiddlewareCall: assert stream.udf_handler, "UdfHandler can not be None" - stream.middlewares = [Middleware(ExceptionMiddleware)] + stream.middlewares + stream.middlewares = [ + Middleware(ExceptionMiddleware, engine=self, error_policy=error_policy), + ] + stream.middlewares next_call = stream.udf_handler for middleware, options in 
reversed(stream.middlewares): @@ -382,9 +390,12 @@ def build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall return next_call async def remove_stream(self, stream: Stream) -> None: + consumer = stream.consumer self._streams.remove(stream) await stream.stop() - self.monitor.clean_stream_consumer_metrics(stream) + + if consumer is not None: + self.monitor.clean_stream_consumer_metrics(consumer=consumer) def stream( self, @@ -396,6 +407,7 @@ def stream( rebalance_listener: typing.Optional[RebalanceListener] = None, middlewares: typing.Optional[typing.List[Middleware]] = None, subscribe_by_pattern: bool = False, + error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP, **kwargs, ) -> typing.Callable[[StreamFunc], Stream]: def decorator(func: StreamFunc) -> Stream: @@ -409,7 +421,7 @@ def decorator(func: StreamFunc) -> Stream: subscribe_by_pattern=subscribe_by_pattern, **kwargs, )(func) - self.add_stream(stream_from_func) + self.add_stream(stream_from_func, error_policy=error_policy) return stream_from_func diff --git a/kstreams/middleware/middleware.py b/kstreams/middleware/middleware.py index e1830abd..c77bc2c5 100644 --- a/kstreams/middleware/middleware.py +++ b/kstreams/middleware/middleware.py @@ -5,9 +5,10 @@ from aiokafka import errors from kstreams import ConsumerRecord, types +from kstreams.streams_utils import StreamErrorPolicy if typing.TYPE_CHECKING: - from kstreams import Stream # pragma: no cover + from kstreams import Stream, StreamEngine # pragma: no cover logger = logging.getLogger(__name__) @@ -61,12 +62,18 @@ async def __call__(self, cr: ConsumerRecord) -> typing.Any: class ExceptionMiddleware(BaseMiddleware): + def __init__( + self, *, engine: "StreamEngine", error_policy: StreamErrorPolicy, **kwargs + ) -> None: + super().__init__(**kwargs) + self.engine = engine + self.error_policy = error_policy + async def __call__(self, cr: ConsumerRecord) -> typing.Any: try: return await self.next_call(cr) except 
errors.ConsumerStoppedError as exc: - await self.cleanup_policy() - raise exc + await self.cleanup_policy(exc) except Exception as exc: logger.exception( "Unhandled error occurred while listening to the stream. " @@ -76,14 +83,23 @@ async def __call__(self, cr: ConsumerRecord) -> typing.Any: exc.add_note(f"Handler: {self.stream.func}") exc.add_note(f"Topics: {self.stream.topics}") - await self.cleanup_policy() - raise exc + await self.cleanup_policy(exc) - async def cleanup_policy(self) -> None: + async def cleanup_policy(self, exc: Exception) -> None: # always release the asyncio.Lock `is_processing` to - # stop properly the `stream` + # stop or restart properly the `stream` self.stream.is_processing.release() - await self.stream.stop() + + if self.error_policy == StreamErrorPolicy.RESTART: + await self.stream.stop() + logger.info(f"Restarting stream {self.stream}") + await self.stream.start() + elif self.error_policy == StreamErrorPolicy.STOP: + await self.stream.stop() + raise exc + else: + await self.engine.stop() + raise exc # acquire the asyncio.Lock `is_processing` again to resume the processing # and avoid `RuntimeError: Lock is not acquired.` diff --git a/kstreams/prometheus/monitor.py b/kstreams/prometheus/monitor.py index 6a7e2ced..d0650be3 100644 --- a/kstreams/prometheus/monitor.py +++ b/kstreams/prometheus/monitor.py @@ -124,31 +124,33 @@ def _clean_consumer_metrics(self) -> None: self.MET_POSITION.clear() self.MET_HIGHWATER.clear() - def clean_stream_consumer_metrics(self, stream: Stream) -> None: - if stream.consumer is not None: - topic_partitions = stream.consumer.assignment() - group_id = stream.consumer._group_id - for topic_partition in topic_partitions: - topic = topic_partition.topic - partition = topic_partition.partition - - metrics_found = False - for sample in self.MET_LAG.collect()[0].samples: - if { - "topic": topic, - "partition": str(partition), - "consumer_group": group_id, - } == sample.labels: - metrics_found = True - - if 
metrics_found: - self.MET_LAG.remove(topic, partition, group_id) - self.MET_POSITION_LAG.remove(topic, partition, group_id) - self.MET_COMMITTED.remove(topic, partition, group_id) - self.MET_POSITION.remove(topic, partition, group_id) - self.MET_HIGHWATER.remove(topic, partition, group_id) - else: - logger.debug(f"Metrics for stream: {stream.name} not found") + def clean_stream_consumer_metrics(self, consumer: Consumer) -> None: + topic_partitions = consumer.assignment() + group_id = consumer._group_id + for topic_partition in topic_partitions: + topic = topic_partition.topic + partition = topic_partition.partition + + metrics_found = False + for sample in self.MET_LAG.collect()[0].samples: + if { + "topic": topic, + "partition": str(partition), + "consumer_group": group_id, + } == sample.labels: + metrics_found = True + + if metrics_found: + self.MET_LAG.remove(topic, partition, group_id) + self.MET_POSITION_LAG.remove(topic, partition, group_id) + self.MET_COMMITTED.remove(topic, partition, group_id) + self.MET_POSITION.remove(topic, partition, group_id) + self.MET_HIGHWATER.remove(topic, partition, group_id) + else: + logger.debug( + "Metrics for consumer with group-id: " + f"{consumer._group_id} not found" + ) def add_producer(self, producer): self._producer = producer diff --git a/kstreams/rebalance_listener.py b/kstreams/rebalance_listener.py index facde047..9421fecc 100644 --- a/kstreams/rebalance_listener.py +++ b/kstreams/rebalance_listener.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Set +import typing from aiokafka.abc import ConsumerRebalanceListener @@ -8,6 +8,9 @@ logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + from kstreams import Stream, StreamEngine # pragma: no cover + # Can not use a Protocol here because aiokafka forces to have a concrete instance # that inherits from ConsumerRebalanceListener, if we use a protocol we will @@ -49,11 +52,11 @@ async def my_stream(stream: Stream): """ def __init__(self) -> 
None: - self.stream = None + self.stream: typing.Optional["Stream"] = None # engine added so it can react on rebalance events - self.engine = None + self.engine: typing.Optional["StreamEngine"] = None - async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: + async def on_partitions_revoked(self, revoked: typing.Set[TopicPartition]) -> None: """ Coroutine to be called *before* a rebalance operation starts and *after* the consumer stops fetching data. @@ -74,7 +77,9 @@ async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: """ ... # pragma: no cover - async def on_partitions_assigned(self, assigned: Set[TopicPartition]) -> None: + async def on_partitions_assigned( + self, assigned: typing.Set[TopicPartition] + ) -> None: """ Coroutine to be called *after* partition re-assignment completes and *before* the consumer starts fetching data again. @@ -98,7 +103,7 @@ async def on_partitions_assigned(self, assigned: Set[TopicPartition]) -> None: class MetricsRebalanceListener(RebalanceListener): - async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: + async def on_partitions_revoked(self, revoked: typing.Set[TopicPartition]) -> None: """ Coroutine to be called *before* a rebalance operation starts and *after* the consumer stops fetching data. 
@@ -112,10 +117,14 @@ async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: # lock all asyncio Tasks so no new metrics will be added to the Monitor if revoked and self.engine is not None: async with asyncio.Lock(): - if self.stream is not None: - self.engine.monitor.clean_stream_consumer_metrics(self.stream) - - async def on_partitions_assigned(self, assigned: Set[TopicPartition]) -> None: + if self.stream is not None and self.stream.consumer is not None: + self.engine.monitor.clean_stream_consumer_metrics( + self.stream.consumer + ) + + async def on_partitions_assigned( + self, assigned: typing.Set[TopicPartition] + ) -> None: """ Coroutine to be called *after* partition re-assignment completes and *before* the consumer starts fetching data again. @@ -134,7 +143,7 @@ async def on_partitions_assigned(self, assigned: Set[TopicPartition]) -> None: class ManualCommitRebalanceListener(MetricsRebalanceListener): - async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: + async def on_partitions_revoked(self, revoked: typing.Set[TopicPartition]) -> None: """ Coroutine to be called *before* a rebalance operation starts and *after* the consumer stops fetching data. @@ -150,6 +159,7 @@ async def on_partitions_revoked(self, revoked: Set[TopicPartition]) -> None: if ( revoked and self.stream is not None + and self.stream.consumer is not None and not self.stream.consumer._enable_auto_commit ): logger.info( diff --git a/kstreams/streams.py b/kstreams/streams.py index 65f61350..f6ff9f61 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -187,6 +187,11 @@ async def stop(self) -> None: if self.consumer is not None: await self.consumer.stop() + # we have to do this operations because aiokafka bug + # https://github.com/aio-libs/aiokafka/issues/1010 + self.consumer.unsubscribe() + self.consumer = None + logger.info( f"Stream consuming from topics {self.topics} has stopped!!! 
\n\n" ) diff --git a/kstreams/streams_utils.py b/kstreams/streams_utils.py index af6d6b2e..ae38e733 100644 --- a/kstreams/streams_utils.py +++ b/kstreams/streams_utils.py @@ -10,6 +10,12 @@ class UDFType(str, enum.Enum): WITH_TYPING = "WITH_TYPING" +class StreamErrorPolicy(str, enum.Enum): + RESTART = "RESTART" + STOP = "STOP" + STOP_ENGINE = "STOP_ENGINE" + + def setup_type(params: List[inspect.Parameter]) -> UDFType: """ Inspect the user defined function (coroutine) to get the proper way to call it diff --git a/tests/conftest.py b/tests/conftest.py index 452638f5..a491aaea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,6 +62,9 @@ def subscribe( ) -> None: self.topics = topics + def unsubscribe(self) -> None: + ... + def assignment(self): return self._assigments diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py index f7979a4c..c7198c31 100644 --- a/tests/middleware/test_middleware.py +++ b/tests/middleware/test_middleware.py @@ -182,7 +182,7 @@ async def stream(cr: ConsumerRecord): await client.send(topic, value=b"test") assert not stream.running - assert stream.consumer._closed + assert stream.consumer is None @pytest.mark.asyncio @@ -199,4 +199,4 @@ async def stream(cr: ConsumerRecord): await client.send(topic, value=b"test") assert not stream.running - assert stream.consumer._closed + assert stream.consumer is None diff --git a/tests/test_client.py b/tests/test_client.py index 0ec46f36..b891829e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -370,7 +370,7 @@ async def test_partitions_for_topic(stream_engine: StreamEngine): client = TestStreamClient(stream_engine) @stream_engine.stream(topic_name, name="my-stream") - async def consume(stream): + async def stream(stream): async for cr in stream: ... 
@@ -380,8 +380,8 @@ async def consume(stream): await client.send(topic_name, value=value, key=key, partition=2) await client.send(topic_name, value=value, key=key, partition=10) - stream = stream_engine.get_stream("my-stream") - assert stream.consumer.partitions_for_topic(topic_name) == set([0, 1, 2, 10]) + await asyncio.sleep(1e-10) + assert stream.consumer.partitions_for_topic(topic_name) == set([0, 1, 2, 10]) @pytest.mark.asyncio @@ -443,9 +443,9 @@ async def my_stream(stream: Stream): ) assert record_metadata.partition == partition - # check that everything was commited - stream = stream_engine.get_stream(name) - assert (await stream.consumer.committed(tp)) == total_events - 1 + await asyncio.sleep(1e-10) + # check that everything was commited + assert (await my_stream.consumer.committed(tp)) == total_events - 1 @pytest.mark.asyncio diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 94c10906..69e5d40a 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -69,11 +69,15 @@ async def on_partitions_assigned(self, assigned: Set[TopicPartition]) -> None: rebalance_listener = MyRebalanceListener() - with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"), mock.patch( - "kstreams.clients.aiokafka.AIOKafkaProducer.start" - ): + with mock.patch.multiple( + Consumer, start=mock.DEFAULT, unsubscribe=mock.DEFAULT + ), mock.patch("kstreams.clients.aiokafka.AIOKafkaProducer.start"): - @stream_engine.stream(topic, rebalance_listener=rebalance_listener) + @stream_engine.stream( + topic, + rebalance_listener=rebalance_listener, + group_id="example-group", + ) async def my_stream(stream: Stream): async for _ in stream: ... 
@@ -98,15 +102,17 @@ async def test_stream_with_default_rebalance_listener(): topic = "local--hello-kpn" topic_partitions = set(TopicPartition(topic=topic, partition=0)) - with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"), mock.patch( - "kstreams.clients.aiokafka.AIOKafkaProducer.start" - ), mock.patch("kstreams.PrometheusMonitor.start") as monitor_start, mock.patch( + with mock.patch.multiple( + Consumer, start=mock.DEFAULT, unsubscribe=mock.DEFAULT + ), mock.patch("kstreams.clients.aiokafka.AIOKafkaProducer.start"), mock.patch( + "kstreams.PrometheusMonitor.start" + ) as monitor_start, mock.patch( "kstreams.PrometheusMonitor.clean_stream_consumer_metrics" ) as clean_stream_metrics: # use this function so we can mock PrometheusMonitor stream_engine = create_engine() - @stream_engine.stream(topic) + @stream_engine.stream(topic, group_id="example-group") async def my_stream(stream: Stream): async for _ in stream: ... @@ -132,7 +138,7 @@ async def my_stream(stream: Stream): monitor_start.assert_awaited_once() # called once on Rebalance with the Stream instance - clean_stream_metrics.assert_called_once_with(my_stream) + clean_stream_metrics.assert_called_once_with(my_stream.consumer) await stream_engine.stop() assert not my_stream.running @@ -143,9 +149,9 @@ async def test_stream_manual_commit_rebalance_listener(stream_engine: StreamEngi topic = "local--hello-kpn" topic_partitions = set(TopicPartition(topic=topic, partition=0)) - with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"), mock.patch( - "kstreams.clients.aiokafka.AIOKafkaConsumer.commit" - ) as commit_mock, mock.patch("kstreams.clients.aiokafka.AIOKafkaProducer.start"): + with mock.patch.multiple( + Consumer, start=mock.DEFAULT, commit=mock.DEFAULT, unsubscribe=mock.DEFAULT + ), mock.patch("kstreams.clients.aiokafka.AIOKafkaProducer.start"): @stream_engine.stream( topic, @@ -171,7 +177,7 @@ async def hello_stream(stream: Stream): ) await 
rebalance_listener.on_partitions_revoked(revoked=topic_partitions) - commit_mock.assert_awaited_once() + Consumer.commit.assert_awaited_once() await stream_engine.stop() await stream_engine.clean_streams() diff --git a/tests/test_monitor.py b/tests/test_monitor.py index bab525fc..270b41a2 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -223,4 +223,4 @@ async def my_coroutine(_): assert len(stream_engine.monitor.MET_POSITION_LAG.collect()[0].samples) == 0 await stream_engine.remove_stream(stream) - assert "Metrics for stream: my-stream-name not found" in caplog.text + assert "Metrics for consumer with group-id: my-group not found" in caplog.text diff --git a/tests/test_streams.py b/tests/test_streams.py index 6741c758..567490b8 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -315,7 +315,7 @@ async def streaming_fn(_): @pytest.mark.asyncio -async def test_no_recreate_consumer_on_re_start_stream( +async def test_recreate_consumer_on_re_start_stream( stream_engine: StreamEngine, consumer_record_factory ): topic_name = "local--kstreams" @@ -340,7 +340,7 @@ async def stream(my_stream): consumer = stream.consumer await stream.stop() await stream.start() - assert consumer is stream.consumer + assert consumer is not stream.consumer @pytest.mark.asyncio diff --git a/tests/test_streams_error_policy.py b/tests/test_streams_error_policy.py new file mode 100644 index 00000000..eb9cc421 --- /dev/null +++ b/tests/test_streams_error_policy.py @@ -0,0 +1,122 @@ +import asyncio +from unittest import mock + +import pytest + +from kstreams import ConsumerRecord, StreamEngine, TestStreamClient +from kstreams.streams_utils import StreamErrorPolicy + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "stream_options", ({}, {"error_policy": StreamErrorPolicy.STOP}) +) +async def test_stop_stream_error_policy(stream_engine: StreamEngine, stream_options): + event = b'{"message": "Hello world!"}' + topic = "kstrems--local" + topic_two = 
"kstrems--local-two" + save_to_db = mock.Mock() + client = TestStreamClient(stream_engine) + + @stream_engine.stream(topic, **stream_options) + async def my_stream(cr: ConsumerRecord): + save_to_db(value=cr.value, key=cr.key) + raise ValueError("Crashing Stream...") + + @stream_engine.stream(topic_two) + async def my_stream_two(cr: ConsumerRecord): + save_to_db(cr.value) + + async with client: + await client.send(topic, value=event, key="1") + + # sleep so the event loop can switch context + await asyncio.sleep(1e-10) + + await client.send(topic_two, value=event, key="2") + + # Streams was stopped before leaving the context + assert not my_stream.running + + # Streams still running before leaving the context + assert my_stream_two.running + + # check that mock was called by both Streams + save_to_db.assert_has_calls( + [ + mock.call(value=b'{"message": "Hello world!"}', key="1"), + mock.call(b'{"message": "Hello world!"}'), + ] + ) + + +@pytest.mark.asyncio +async def test_stop_engine_error_policy(stream_engine: StreamEngine): + event = b'{"message": "Hello world!"}' + topic = "kstrems--local" + topic_two = "kstrems--local-two" + save_to_db = mock.Mock() + client = TestStreamClient(stream_engine) + + @stream_engine.stream(topic, error_policy=StreamErrorPolicy.STOP_ENGINE) + async def my_stream(cr: ConsumerRecord): + raise ValueError("Crashing Stream...") + + @stream_engine.stream(topic_two) + async def my_stream_two(cr: ConsumerRecord): + save_to_db(cr.value) + + async with client: + # send event and crash the first Stream, then the second one + # should be stopped because of StreamErrorPolicy.STOP_ENGINE + await client.send(topic, value=event, key="1") + + # Send an event to the second Stream, it should be consumed + # as the Stream has been stopped + await client.send(topic_two, value=event, key="1") + + # Both streams are stopped before leaving the context + assert not my_stream.running + assert not my_stream_two.running + + # check that the event was consumed 
only once. + # The StreamEngine must wait for graceful shutdown + save_to_db.assert_called_once_with(b'{"message": "Hello world!"}') + + +@pytest.mark.asyncio +async def test_restart_stream_error_policy(stream_engine: StreamEngine): + event = b'{"message": "Hello world!"}' + topic = "kstrems--local-kskss" + save_to_db = mock.Mock() + client = TestStreamClient(stream_engine) + + @stream_engine.stream(topic, error_policy=StreamErrorPolicy.RESTART) + async def my_stream(cr: ConsumerRecord): + if cr.key == "1": + raise ValueError("Crashing Stream...") + save_to_db(value=cr.value, key=cr.key) + + async with client: + await client.send(topic, value=event, key="2") + + # send event to crash the Stream but it should be restarted + # because of StreamErrorPolicy.RESTART + await client.send(topic, value=event, key="1") + + # Send another event to make sure that the Stream is not dead + await client.send(topic, value=event, key="3") + + # sleep so the event loop can switch context + await asyncio.sleep(1e-10) + + # The Stream is still running because it was restarted after the crash + assert my_stream.running + + # check that the Stream has consumed two events + save_to_db.assert_has_calls( + [ + mock.call(value=b'{"message": "Hello world!"}', key="2"), + mock.call(value=b'{"message": "Hello world!"}', key="3"), + ] + )