diff --git a/fixcloudutils/redis/event_stream.py b/fixcloudutils/redis/event_stream.py index eb4a2c0..a8c81a9 100644 --- a/fixcloudutils/redis/event_stream.py +++ b/fixcloudutils/redis/event_stream.py @@ -33,6 +33,7 @@ import sys import uuid from asyncio import Task +from collections import defaultdict from contextlib import suppress from datetime import datetime, timedelta from functools import partial @@ -74,6 +75,7 @@ class Backoff: base_delay: float maximum_delay: float retries: int + log_failed_attempts: bool = True def wait_time(self, attempt: int) -> float: delay: float = self.base_delay * (2**attempt + random.uniform(0, 1)) @@ -85,7 +87,8 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) - except Exception as e: if attempt < self.retries: delay = self.wait_time(attempt) - log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}") + if self.log_failed_attempts: + log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}") await asyncio.sleep(delay) return await self.with_backoff(fn, attempt + 1) else: @@ -93,6 +96,7 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) - NoBackoff = Backoff(0, 0, 0) +DefaultBackoff = Backoff(0.1, 10, 10) @define(frozen=True, slots=True) @@ -125,7 +129,7 @@ def __init__( consider_failed_after: timedelta, batch_size: int = 1000, stop_on_fail: bool = False, - backoff: Optional[Backoff] = Backoff(0.1, 10, 10), + backoff: Optional[Dict[str, Backoff]] = None, parallelism: Optional[int] = None, ) -> None: """ @@ -140,7 +144,8 @@ def __init__( :param consider_failed_after: The time after which a message is considered failed and will be retried. :param batch_size: The number of events to read in one batch. :param stop_on_fail: If True, the listener will stop if a failed event is retried too many times. - :param backoff: The backoff strategy to use when retrying failed events. + :param backoff: The backoff strategy for the defined message kind to use when retrying failed events. + The DefaultBackoff is used if no value is provided. :param parallelism: If provided, messages will be processed in parallel without order. """ self.redis = redis @@ -150,7 +155,7 @@ def __init__( self.message_processor = message_processor self.batch_size = batch_size self.stop_on_fail = stop_on_fail - self.backoff = backoff or NoBackoff + self.backoff = defaultdict(lambda: DefaultBackoff) if backoff is None else backoff self.__should_run = True self.__listen_task: Optional[Task[Any]] = None # Check for messages that are not processed for a long time by any listener. Try to claim and process them. @@ -195,9 +200,9 @@ async def _handle_stream_messages(self, messages: List[Any]) -> None: # acknowledge all processed messages await self.redis.xack(self.stream, self.group, *ids) - async def _handle_stream_messages_parallel(self, messages: List[Any], max_parallelesm: int) -> None: + async def _handle_stream_messages_parallel(self, messages: List[Any], max_parallelism: int) -> None: """ - Handle messages in parallel in unordered fasion. The number of parallel tasks is limited by max_parallelism. + Handle messages in parallel in an unordered fashion. The number of parallel tasks is limited by max_parallelism. """ async def handle_and_ack(msg: Any, message_id: StreamIdT) -> None: @@ -210,9 +215,8 @@ def task_done_callback(task: Task[Any]) -> None: for stream, stream_messages in messages: log.debug(f"Handle {len(stream_messages)} messages from stream.") for uid, data in stream_messages: - if len(self._ongoing_tasks) >= max_parallelesm: # queue is full, wait for a slot to be freed + while len(self._ongoing_tasks) >= max_parallelism: # queue is full, wait for a slot to be freed await asyncio.wait(self._ongoing_tasks, return_when=asyncio.FIRST_COMPLETED) - task = asyncio.create_task(handle_and_ack(data, uid), name=f"handle_message_{uid}") task.add_done_callback(task_done_callback) self._ongoing_tasks.add(task) @@ -220,16 +224,17 @@ def task_done_callback(task: Task[Any]) -> None: async def _handle_single_message(self, message: Json) -> None: try: if "id" in message and "at" in message and "data" in message: + kind = message["kind"] context = MessageContext( id=message["id"], - kind=message["kind"], + kind=kind, publisher=message["publisher"], sent_at=parse_utc_str(message["at"]), received_at=utc(), ) data = json.loads(message["data"]) log.debug(f"Received message {self.listener}: message {context} data: {data}") - await self.backoff.with_backoff(partial(self.message_processor, data, context)) + await self.backoff[kind].with_backoff(partial(self.message_processor, data, context)) else: log.warning(f"Invalid message format: {message}. Ignore.") except Exception as e: @@ -303,12 +308,7 @@ async def read_all() -> None: await self.__outdated_messages_task.start() async def stop(self) -> Any: - async def stop_task(task: Task[Any]) -> None: - task.cancel() - with suppress(asyncio.CancelledError): - await task - - await asyncio.gather(*[stop_task(task) for task in self._ongoing_tasks]) + await asyncio.gather(*[stop_running_task(task) for task in self._ongoing_tasks]) self.__should_run = False await self.__outdated_messages_task.stop() await stop_running_task(self.__listen_task) diff --git a/tests/conftest.py b/tests/conftest.py index c3573ee..3dcd4da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,22 +25,24 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List +from typing import List, AsyncIterator from arango.client import ArangoClient from attr import define from pytest import fixture from redis.asyncio import Redis -from redis.backoff import ExponentialBackoff from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff from fixcloudutils.arangodb.async_arangodb import AsyncArangoDB @fixture -def redis() -> Redis: +async def redis() -> AsyncIterator[Redis]: backoff = ExponentialBackoff() # type: ignore - return Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10)) + redis = Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10)) + yield redis + await redis.close(True) @fixture diff --git a/tests/event_stream_test.py b/tests/event_stream_test.py index 89c70d2..88fa97f 100644 --- a/tests/event_stream_test.py +++ b/tests/event_stream_test.py @@ -298,7 +298,13 @@ async def handle_message(message: Json, context: Any) -> None: # a new redis listener started later will receive all messages async with RedisStreamListener( - redis, "test-stream", "t1", "l1", handle_message, timedelta(seconds=5), backoff=Backoff(0, 0, 5) + redis, + "test-stream", + "t1", + "l1", + handle_message, + timedelta(seconds=5), + backoff=defaultdict(lambda: Backoff(0, 0, 5)), ): async with RedisStreamPublisher(redis, "test-stream", "test") as publisher: await publisher.publish("test_data", unstructure(ExampleData(1, "foo", [1, 2, 3])))