From 7307bb0476204ca5158354d2079134eab41520c9 Mon Sep 17 00:00:00 2001
From: Nikita Melkozerov
Date: Mon, 30 Oct 2023 12:54:53 +0100
Subject: [PATCH] Add a parallel mode to the stream listener (#16)

* Add a parallel mode to the stream listener

* fix runtime error message check

* linter fix

* mypy fix

* raise the exception properly

* fix my comments

---------

Co-authored-by: Matthias Veit
---
 fixcloudutils/redis/event_stream.py |  54 ++++++++++--
 tests/conftest.py                   |  10 ++-
 tests/event_stream_test.py          | 124 +++++++++++++++++++++++++++-
 3 files changed, 175 insertions(+), 13 deletions(-)

diff --git a/fixcloudutils/redis/event_stream.py b/fixcloudutils/redis/event_stream.py
index 0a7cdac..a8c81a9 100644
--- a/fixcloudutils/redis/event_stream.py
+++ b/fixcloudutils/redis/event_stream.py
@@ -33,6 +33,7 @@
 import sys
 import uuid
 from asyncio import Task
+from collections import defaultdict
 from contextlib import suppress
 from datetime import datetime, timedelta
 from functools import partial
@@ -45,10 +46,12 @@
     TypeVar,
     Dict,
     List,
+    Set,
 )
 
 from attrs import define
 from redis.asyncio import Redis
+from redis.typing import StreamIdT
 
 from fixcloudutils.asyncio import stop_running_task
 from fixcloudutils.asyncio.periodic import Periodic
@@ -72,6 +75,7 @@ class Backoff:
     base_delay: float
     maximum_delay: float
     retries: int
+    log_failed_attempts: bool = True
 
     def wait_time(self, attempt: int) -> float:
         delay: float = self.base_delay * (2**attempt + random.uniform(0, 1))
@@ -83,7 +87,8 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) -
         except Exception as e:
             if attempt < self.retries:
                 delay = self.wait_time(attempt)
-                log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}")
+                if self.log_failed_attempts:
+                    log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}")
                 await asyncio.sleep(delay)
                 return await self.with_backoff(fn, attempt + 1)
             else:
@@ -91,6 +96,7 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) -
 
 
 NoBackoff = Backoff(0, 0, 0)
+DefaultBackoff = Backoff(0.1, 10, 10)
 
 
 @define(frozen=True, slots=True)
@@ -123,7 +129,8 @@ def __init__(
         consider_failed_after: timedelta,
         batch_size: int = 1000,
         stop_on_fail: bool = False,
-        backoff: Optional[Backoff] = Backoff(0.1, 10, 10),
+        backoff: Optional[Dict[str, Backoff]] = None,
+        parallelism: Optional[int] = None,
     ) -> None:
         """
         Create a RedisStream client.
@@ -137,7 +144,9 @@ def __init__(
         :param consider_failed_after: The time after which a message is considered failed and will be retried.
        :param batch_size: The number of events to read in one batch.
         :param stop_on_fail: If True, the listener will stop if a failed event is retried too many times.
-        :param backoff: The backoff strategy to use when retrying failed events.
+        :param backoff: The backoff strategy per message kind, used when retrying failed events.
+               DefaultBackoff is used for all kinds if no mapping is provided.
+        :param parallelism: If provided, messages are processed in parallel, without ordering guarantees.
         """
         self.redis = redis
         self.stream = stream
@@ -146,7 +155,7 @@ def __init__(
         self.message_processor = message_processor
         self.batch_size = batch_size
         self.stop_on_fail = stop_on_fail
-        self.backoff = backoff or NoBackoff
+        self.backoff = defaultdict(lambda: DefaultBackoff) if backoff is None else backoff
         self.__should_run = True
         self.__listen_task: Optional[Task[Any]] = None
         # Check for messages that are not processed for a long time by any listener. Try to claim and process them.
@@ -157,6 +166,8 @@ def __init__(
             first_run=timedelta(seconds=3),
         )
         self.__readpos = ">"
+        self._ongoing_tasks: Set[Task[Any]] = set()
+        self.parallelism = parallelism
 
     async def _listen(self) -> None:
         while self.__should_run:
@@ -165,9 +176,13 @@ async def _listen(self) -> None:
                     self.group, self.listener, {self.stream: self.__readpos}, count=self.batch_size, block=1000
                 )
                 self.__readpos = ">"
-
-                await self._handle_stream_messages(messages)
+                if self.parallelism:
+                    await self._handle_stream_messages_parallel(messages, self.parallelism)
+                else:
+                    await self._handle_stream_messages(messages)
             except Exception as e:
+                if isinstance(e, RuntimeError) and len(e.args) and e.args[0] == "no running event loop":
+                    raise e
                 log.error(f"Failed to read from stream {self.stream}: {e}", exc_info=True)
                 if self.stop_on_fail:
                     raise
@@ -185,19 +200,41 @@ async def _handle_stream_messages(self, messages: List[Any]) -> None:
         # acknowledge all processed messages
         await self.redis.xack(self.stream, self.group, *ids)
 
+    async def _handle_stream_messages_parallel(self, messages: List[Any], max_parallelism: int) -> None:
+        """
+        Handle messages in parallel in an unordered fashion. The number of parallel tasks is limited by max_parallelism.
+        """
+
+        async def handle_and_ack(msg: Any, message_id: StreamIdT) -> None:
+            await self._handle_single_message(msg)
+            await self.redis.xack(self.stream, self.group, message_id)
+
+        def task_done_callback(task: Task[Any]) -> None:
+            self._ongoing_tasks.discard(task)
+
+        for stream, stream_messages in messages:
+            log.debug(f"Handle {len(stream_messages)} messages from stream.")
+            for uid, data in stream_messages:
+                while len(self._ongoing_tasks) >= max_parallelism:  # queue is full, wait for a slot to be freed
+                    await asyncio.wait(self._ongoing_tasks, return_when=asyncio.FIRST_COMPLETED)
+                task = asyncio.create_task(handle_and_ack(data, uid), name=f"handle_message_{uid}")
+                task.add_done_callback(task_done_callback)
+                self._ongoing_tasks.add(task)
+
     async def _handle_single_message(self, message: Json) -> None:
         try:
             if "id" in message and "at" in message and "data" in message:
+                kind = message["kind"]
                 context = MessageContext(
                     id=message["id"],
-                    kind=message["kind"],
+                    kind=kind,
                     publisher=message["publisher"],
                     sent_at=parse_utc_str(message["at"]),
                     received_at=utc(),
                 )
                 data = json.loads(message["data"])
                 log.debug(f"Received message {self.listener}: message {context} data: {data}")
-                await self.backoff.with_backoff(partial(self.message_processor, data, context))
+                await self.backoff[kind].with_backoff(partial(self.message_processor, data, context))
             else:
                 log.warning(f"Invalid message format: {message}. Ignore.")
         except Exception as e:
@@ -271,6 +308,7 @@ async def read_all() -> None:
         await self.__outdated_messages_task.start()
 
     async def stop(self) -> Any:
+        await asyncio.gather(*[stop_running_task(task) for task in self._ongoing_tasks])
         self.__should_run = False
         await self.__outdated_messages_task.stop()
         await stop_running_task(self.__listen_task)
diff --git a/tests/conftest.py b/tests/conftest.py
index c3573ee..3dcd4da 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,22 +25,24 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
-from typing import List
+from typing import List, AsyncIterator
 
 from arango.client import ArangoClient
 from attr import define
 from pytest import fixture
 from redis.asyncio import Redis
-from redis.backoff import ExponentialBackoff
 from redis.asyncio.retry import Retry
+from redis.backoff import ExponentialBackoff
 
 from fixcloudutils.arangodb.async_arangodb import AsyncArangoDB
 
 
 @fixture
-def redis() -> Redis:
+async def redis() -> AsyncIterator[Redis]:
     backoff = ExponentialBackoff()  # type: ignore
-    return Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
+    redis = Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
+    yield redis
+    await redis.close(True)
 
 
 @fixture
diff --git a/tests/event_stream_test.py b/tests/event_stream_test.py
index c7298e5..88fa97f 100644
--- a/tests/event_stream_test.py
+++ b/tests/event_stream_test.py
@@ -128,6 +128,122 @@ async def check_all_arrived(expected_reader: int) -> bool:
     await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")
 
 
+@pytest.mark.asyncio
+@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
+async def test_stream_parallel(redis: Redis) -> None:
+    counter: List[int] = [0]
+
+    async def handle_message(group: int, uid: int, message: Json, _: MessageContext) -> None:
+        # make sure we can read the message
+        data = structure(message, ExampleData)
+        assert data.bar == "foo"
+        assert data.bla == [1, 2, 3]
+        await asyncio.sleep(0.5)  # message takes time to be processed
+        counter[0] += 1
+
+    # clean slate
+    await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")
+
+    # create a single listener
+    stream = RedisStreamListener(
+        redis, "test-stream", "group", "id", partial(handle_message, 1, 1), timedelta(seconds=1), parallelism=10
+    )
+    await stream.start()
+
+    messages_total = 10
+    # publish 10 messages
+    publisher = RedisStreamPublisher(redis, "test-stream", "test")
+    for i in range(messages_total):
+        await publisher.publish("test_data", unstructure(ExampleData(i, "foo", [1, 2, 3])))
+
+    # make sure messages are in the stream
+    assert (await redis.xlen("test-stream")) == messages_total
+
+    # wait until all 10 messages have been processed
+    async def check_all_arrived(expected_reader: int) -> bool:
+        while True:
+            if counter[0] == expected_reader:
+                return True
+            await asyncio.sleep(0.1)
+
+    # processing must happen in parallel, otherwise we hit the timeout:
+    # without parallelism, handling would take 5 seconds (10 messages * 0.5 seconds)
+    # and the test would fail
+    await asyncio.wait_for(check_all_arrived(messages_total), timeout=2)
+
+    # messages must be acked and not be processed again
+    await asyncio.sleep(1)
+    assert counter[0] == messages_total
+
+    # no tasks should be running once everything is processed
+    assert len(stream._ongoing_tasks) == 0
+
+    # stop the listener
+    await stream.stop()
+
+    # don't leave any traces
+    await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
+async def test_stream_parallel_backpressure(redis: Redis) -> None:
+    counter: List[int] = [0]
+
+    async def handle_message(group: int, uid: int, message: Json, _: MessageContext) -> None:
+        # make sure we can read the message
+        data = structure(message, ExampleData)
+        assert data.bar == "foo"
+        assert data.bla == [1, 2, 3]
+        await asyncio.sleep(0.15)  # message takes time to be processed
+        counter[0] += 1
+
+    # clean slate
+    await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")
+
+    # create a single listener
+    stream = RedisStreamListener(
+        redis, "test-stream", "group", "id", partial(handle_message, 1, 1), timedelta(seconds=1), parallelism=1
+    )
+    await stream.start()
+
+    messages_total = 10
+    # publish 10 messages
+    publisher = RedisStreamPublisher(redis, "test-stream", "test")
+    for i in range(messages_total):
+        await publisher.publish("test_data", unstructure(ExampleData(i, "foo", [1, 2, 3])))
+
+    # make sure messages are in the stream
+    assert (await redis.xlen("test-stream")) == messages_total
+
+    # wait until all 10 messages have been processed
+    async def check_all_arrived(expected_reader: int) -> bool:
+        while True:
+            if counter[0] == expected_reader:
+                return True
+            await asyncio.sleep(0.1)
+
+    # when the parallelism limit is reached, the listener must wait before starting the next task,
+    # so total processing time should be at least 1.5 seconds (10 messages * 0.15 seconds)
+    before = asyncio.get_running_loop().time()
+    await asyncio.wait_for(check_all_arrived(messages_total), timeout=2)
+    after = asyncio.get_running_loop().time()
+    assert after - before >= 1.5
+
+    # messages must be acked and not be processed again
+    await asyncio.sleep(1)
+    assert counter[0] == messages_total
+
+    # no tasks should be running once everything is processed
+    assert len(stream._ongoing_tasks) == 0
+
+    # stop the listener
+    await stream.stop()
+
+    # don't leave any traces
+    await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")
+
+
 @pytest.mark.asyncio
 @pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
 async def test_stream_pending(redis: Redis) -> None:
@@ -182,7 +298,13 @@ async def handle_message(message: Json, context: Any) -> None:
 
     # a new redis listener started later will receive all messages
     async with RedisStreamListener(
-        redis, "test-stream", "t1", "l1", handle_message, timedelta(seconds=5), backoff=Backoff(0, 0, 5)
+        redis,
+        "test-stream",
+        "t1",
+        "l1",
+        handle_message,
+        timedelta(seconds=5),
+        backoff=defaultdict(lambda: Backoff(0, 0, 5)),
     ):
         async with RedisStreamPublisher(redis, "test-stream", "test") as publisher:
             await publisher.publish("test_data", unstructure(ExampleData(1, "foo", [1, 2, 3])))
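
Usage sketch (illustrative, not part of the patch): the snippet below wires up the two
options this change introduces, the per-kind backoff mapping and parallelism. The stream
name "orders", group "billing", listener id "billing-1", the "flaky_event" kind, and
handle_event are hypothetical placeholders; RedisStreamListener, Backoff, DefaultBackoff,
and MessageContext come from fixcloudutils.redis.event_stream, and Json is assumed to be
importable from fixcloudutils.types. Because _handle_single_message looks up
self.backoff[kind], a plain dict would raise KeyError for unmapped kinds, so a defaultdict
is used, mirroring the tests above.

    import asyncio
    from collections import defaultdict
    from datetime import timedelta

    from redis.asyncio import Redis

    from fixcloudutils.redis.event_stream import (
        Backoff,
        DefaultBackoff,
        MessageContext,
        RedisStreamListener,
    )
    from fixcloudutils.types import Json


    async def handle_event(message: Json, context: MessageContext) -> None:
        # with parallelism enabled, invocations may overlap and complete out of order
        print(f"{context.kind} from {context.publisher}: {message}")


    async def main() -> None:
        redis = Redis(host="localhost", port=6379, decode_responses=True)
        listener = RedisStreamListener(
            redis,
            "orders",
            "billing",  # consumer group
            "billing-1",  # listener id within the group
            handle_event,
            timedelta(seconds=30),  # consider_failed_after
            # retry "flaky_event" up to 3 times without delay; every other kind
            # falls back to DefaultBackoff (0.1s base delay, 10s max, 10 retries)
            backoff=defaultdict(lambda: DefaultBackoff, {"flaky_event": Backoff(0, 0, 3)}),
            # handle up to 16 messages concurrently; each message is acknowledged
            # individually as soon as its handler finishes
            parallelism=16,
        )
        await listener.start()
        try:
            await asyncio.sleep(60)  # serve messages for a while
        finally:
            await listener.stop()
            await redis.close(True)


    if __name__ == "__main__":
        asyncio.run(main())

Note the backpressure behavior this relies on: once `parallelism` tasks are in flight,
_handle_stream_messages_parallel blocks on asyncio.wait(..., return_when=FIRST_COMPLETED)
before starting the next task, which is exactly what test_stream_parallel_backpressure
asserts with its 1.5-second lower bound.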