Add a parallel mode to the stream listener (#16)
* Add a parallel mode to the stream listener

* fix runtime error message check

* linter fix

* mypy fix

* raise the exception properly

* fix my comments

---------

Co-authored-by: Matthias Veit <[email protected]>
meln1k and aquamatthias authored Oct 30, 2023
1 parent 3ad28f6 commit 7307bb0
Showing 3 changed files with 175 additions and 13 deletions.
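To make the change concrete, here is a minimal usage sketch of the new parallel mode, assuming the constructor signature shown in the diff below; the stream name, consumer group, handler body, import paths, and Redis connection details are illustrative assumptions, not part of this commit.

import asyncio
from collections import defaultdict
from datetime import timedelta
from typing import Any, Dict

from redis.asyncio import Redis

# assumed import location: the classes are defined in fixcloudutils/redis/event_stream.py
from fixcloudutils.redis.event_stream import Backoff, MessageContext, RedisStreamListener


async def handle_message(message: Dict[str, Any], context: MessageContext) -> None:
    ...  # process one message; with parallelism enabled, handlers may run concurrently


async def main() -> None:
    redis = Redis(host="localhost", port=6379, decode_responses=True)
    listener = RedisStreamListener(
        redis,
        "example-stream",
        "example-group",
        "listener-1",
        handle_message,
        timedelta(seconds=30),  # consider_failed_after
        parallelism=10,  # new: process up to 10 messages concurrently, without ordering guarantees
        backoff=defaultdict(lambda: Backoff(0.1, 10, 5)),  # new: backoff is now a per-message-kind mapping
    )
    await listener.start()
    # ... run until shutdown ...
    await listener.stop()


asyncio.run(main())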
54 changes: 46 additions & 8 deletions fixcloudutils/redis/event_stream.py
@@ -33,6 +33,7 @@
import sys
import uuid
from asyncio import Task
from collections import defaultdict
from contextlib import suppress
from datetime import datetime, timedelta
from functools import partial
@@ -45,10 +46,12 @@
TypeVar,
Dict,
List,
Set,
)

from attrs import define
from redis.asyncio import Redis
from redis.typing import StreamIdT

from fixcloudutils.asyncio import stop_running_task
from fixcloudutils.asyncio.periodic import Periodic
@@ -72,6 +75,7 @@ class Backoff:
base_delay: float
maximum_delay: float
retries: int
log_failed_attempts: bool = True

def wait_time(self, attempt: int) -> float:
delay: float = self.base_delay * (2**attempt + random.uniform(0, 1))
@@ -83,14 +87,16 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) -> T:
except Exception as e:
if attempt < self.retries:
delay = self.wait_time(attempt)
log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}")
if self.log_failed_attempts:
log.warning(f"Got Exception in attempt {attempt}. Retry after {delay} seconds: {e}")
await asyncio.sleep(delay)
return await self.with_backoff(fn, attempt + 1)
else:
raise


NoBackoff = Backoff(0, 0, 0)
DefaultBackoff = Backoff(0.1, 10, 10)


@define(frozen=True, slots=True)
@@ -123,7 +129,8 @@ def __init__(
consider_failed_after: timedelta,
batch_size: int = 1000,
stop_on_fail: bool = False,
backoff: Optional[Backoff] = Backoff(0.1, 10, 10),
backoff: Optional[Dict[str, Backoff]] = None,
parallelism: Optional[int] = None,
) -> None:
"""
Create a RedisStream client.
@@ -137,7 +144,9 @@ def __init__(
:param consider_failed_after: The time after which a message is considered failed and will be retried.
:param batch_size: The number of events to read in one batch.
:param stop_on_fail: If True, the listener will stop if a failed event is retried too many times.
:param backoff: The backoff strategy to use when retrying failed events.
:param backoff: The backoff strategy, per message kind, to use when retrying failed events.
DefaultBackoff is used for all message kinds if no mapping is provided.
:param parallelism: If provided, messages are processed in parallel without ordering guarantees.
"""
self.redis = redis
self.stream = stream
@@ -146,7 +155,7 @@ def __init__(
self.message_processor = message_processor
self.batch_size = batch_size
self.stop_on_fail = stop_on_fail
self.backoff = backoff or NoBackoff
self.backoff = defaultdict(lambda: DefaultBackoff) if backoff is None else backoff
self.__should_run = True
self.__listen_task: Optional[Task[Any]] = None
# Check for messages that are not processed for a long time by any listener. Try to claim and process them.
@@ -157,6 +166,8 @@ def __init__(
first_run=timedelta(seconds=3),
)
self.__readpos = ">"
self._ongoing_tasks: Set[Task[Any]] = set()
self.parallelism = parallelism

async def _listen(self) -> None:
while self.__should_run:
@@ -165,9 +176,13 @@ async def _listen(self) -> None:
self.group, self.listener, {self.stream: self.__readpos}, count=self.batch_size, block=1000
)
self.__readpos = ">"

await self._handle_stream_messages(messages)
if self.parallelism:
await self._handle_stream_messages_parallel(messages, self.parallelism)
else:
await self._handle_stream_messages(messages)
except Exception as e:
if isinstance(e, RuntimeError) and len(e.args) and e.args[0] == "no running event loop":
raise e
log.error(f"Failed to read from stream {self.stream}: {e}", exc_info=True)
if self.stop_on_fail:
raise
@@ -185,19 +200,41 @@ async def _handle_stream_messages(self, messages: List[Any]) -> None:
# acknowledge all processed messages
await self.redis.xack(self.stream, self.group, *ids)

async def _handle_stream_messages_parallel(self, messages: List[Any], max_parallelism: int) -> None:
"""
Handle messages in parallel in an unordered fashion. The number of parallel tasks is limited by max_parallelism.
"""

async def handle_and_ack(msg: Any, message_id: StreamIdT) -> None:
await self._handle_single_message(msg)
await self.redis.xack(self.stream, self.group, message_id)

def task_done_callback(task: Task[Any]) -> None:
self._ongoing_tasks.discard(task)

for stream, stream_messages in messages:
log.debug(f"Handle {len(stream_messages)} messages from stream.")
for uid, data in stream_messages:
while len(self._ongoing_tasks) >= max_parallelism: # queue is full, wait for a slot to be freed
await asyncio.wait(self._ongoing_tasks, return_when=asyncio.FIRST_COMPLETED)
task = asyncio.create_task(handle_and_ack(data, uid), name=f"handle_message_{uid}")
task.add_done_callback(task_done_callback)
self._ongoing_tasks.add(task)

async def _handle_single_message(self, message: Json) -> None:
try:
if "id" in message and "at" in message and "data" in message:
kind = message["kind"]
context = MessageContext(
id=message["id"],
kind=message["kind"],
kind=kind,
publisher=message["publisher"],
sent_at=parse_utc_str(message["at"]),
received_at=utc(),
)
data = json.loads(message["data"])
log.debug(f"Received message {self.listener}: message {context} data: {data}")
await self.backoff.with_backoff(partial(self.message_processor, data, context))
await self.backoff[kind].with_backoff(partial(self.message_processor, data, context))
else:
log.warning(f"Invalid message format: {message}. Ignore.")
except Exception as e:
@@ -271,6 +308,7 @@ async def read_all() -> None:
await self.__outdated_messages_task.start()

async def stop(self) -> Any:
await asyncio.gather(*[stop_running_task(task) for task in self._ongoing_tasks])
self.__should_run = False
await self.__outdated_messages_task.stop()
await stop_running_task(self.__listen_task)
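The heart of the parallel path above is a bounded set of in-flight tasks: a new task is only scheduled once the set drops below the limit, and each task removes itself from the set when it finishes. A self-contained sketch of that back-pressure pattern follows; the function names and the dummy workload are illustrative stand-ins for message handling and acknowledgement.

import asyncio
from typing import Any, List, Set


async def process(item: int) -> None:
    await asyncio.sleep(0.1)  # stand-in for handling and acknowledging one message


async def handle_all(items: List[int], max_parallelism: int) -> None:
    ongoing: Set[asyncio.Task[Any]] = set()
    for item in items:
        # back-pressure: wait until at least one in-flight task has finished
        while len(ongoing) >= max_parallelism:
            await asyncio.wait(ongoing, return_when=asyncio.FIRST_COMPLETED)
        task = asyncio.create_task(process(item))
        task.add_done_callback(ongoing.discard)  # task removes itself from the set once done
        ongoing.add(task)
    if ongoing:
        await asyncio.gather(*ongoing)  # drain the remaining tasks


asyncio.run(handle_all(list(range(25)), max_parallelism=5))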
10 changes: 6 additions & 4 deletions tests/conftest.py
@@ -25,22 +25,24 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from typing import List
from typing import List, AsyncIterator

from arango.client import ArangoClient
from attr import define
from pytest import fixture
from redis.asyncio import Redis
from redis.backoff import ExponentialBackoff
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff

from fixcloudutils.arangodb.async_arangodb import AsyncArangoDB


@fixture
def redis() -> Redis:
async def redis() -> AsyncIterator[Redis]:
backoff = ExponentialBackoff() # type: ignore
return Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
redis = Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
yield redis
await redis.close(True)


@fixture
124 changes: 123 additions & 1 deletion tests/event_stream_test.py
@@ -128,6 +128,122 @@ async def check_all_arrived(expected_reader: int) -> bool:
await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")


@pytest.mark.asyncio
@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
async def test_stream_parallel(redis: Redis) -> None:
counter: List[int] = [0]

async def handle_message(group: int, uid: int, message: Json, _: MessageContext) -> None:
# make sure we can read the message
data = structure(message, ExampleData)
assert data.bar == "foo"
assert data.bla == [1, 2, 3]
await asyncio.sleep(0.5) # message takes time to be processed
counter[0] += 1

# clean slate
await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")

# create a single listener
stream = RedisStreamListener(
redis, "test-stream", "group", "id", partial(handle_message, 1, 1), timedelta(seconds=1), parallelism=10
)
await stream.start()

messages_total = 10
# publish 10 messages
publisher = RedisStreamPublisher(redis, "test-stream", "test")
for i in range(messages_total):
await publisher.publish("test_data", unstructure(ExampleData(i, "foo", [1, 2, 3])))

# make sure messages are in the stream
assert (await redis.xlen("test-stream")) == messages_total

# expect all 10 messages to be processed
async def check_all_arrived(expected_reader: int) -> bool:
while True:
if counter[0] == expected_reader:
return True
await asyncio.sleep(0.1)

# processing must happen in parallel so that we don't hit the timeout:
# without parallelism, handling 10 messages sequentially would take 5 seconds
# and the test would fail
await asyncio.wait_for(check_all_arrived(messages_total), timeout=2)

# messages must be acked and not be processed again
await asyncio.sleep(1)
assert counter[0] == messages_total

# no tasks should be running once everything is processed
assert len(stream._ongoing_tasks) == 0

# stop all listeners
await stream.stop()

# don't leave any traces
await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")


@pytest.mark.asyncio
@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
async def test_stream_parallel_backpressure(redis: Redis) -> None:
counter: List[int] = [0]

async def handle_message(group: int, uid: int, message: Json, _: MessageContext) -> None:
# make sure we can read the message
data = structure(message, ExampleData)
assert data.bar == "foo"
assert data.bla == [1, 2, 3]
await asyncio.sleep(0.15) # message takes time to be processed
counter[0] += 1

# clean slate
await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")

# create a single listener
stream = RedisStreamListener(
redis, "test-stream", "group", "id", partial(handle_message, 1, 1), timedelta(seconds=1), parallelism=1
)
await stream.start()

messages_total = 10
# publish 10 messages
publisher = RedisStreamPublisher(redis, "test-stream", "test")
for i in range(messages_total):
await publisher.publish("test_data", unstructure(ExampleData(i, "foo", [1, 2, 3])))

# make sure messages are in the stream
assert (await redis.xlen("test-stream")) == messages_total

# expect all 10 messages to be processed
async def check_all_arrived(expected_reader: int) -> bool:
while True:
if counter[0] == expected_reader:
return True
await asyncio.sleep(0.1)

# with parallelism=1 the listener must wait before enqueueing the next message,
# so the total processing time should be at least 1.5 seconds (10 messages * 0.15 seconds)
before = asyncio.get_running_loop().time()
await asyncio.wait_for(check_all_arrived(messages_total), timeout=2)
after = asyncio.get_running_loop().time()
assert after - before >= 1.5

# messages must be acked and not be processed again
await asyncio.sleep(1)
assert counter[0] == messages_total

# no tasks should be running once everything is processed
assert len(stream._ongoing_tasks) == 0

# stop all listeners
await stream.stop()

# don't leave any traces
await redis.delete("test-stream", "test-stream.listener", "test-stream.dlq")


@pytest.mark.asyncio
@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
async def test_stream_pending(redis: Redis) -> None:
@@ -182,7 +298,13 @@ async def handle_message(message: Json, context: Any) -> None:

# a new redis listener started later will receive all messages
async with RedisStreamListener(
redis, "test-stream", "t1", "l1", handle_message, timedelta(seconds=5), backoff=Backoff(0, 0, 5)
redis,
"test-stream",
"t1",
"l1",
handle_message,
timedelta(seconds=5),
backoff=defaultdict(lambda: Backoff(0, 0, 5)),
):
async with RedisStreamPublisher(redis, "test-stream", "test") as publisher:
await publisher.publish("test_data", unstructure(ExampleData(1, "foo", [1, 2, 3])))
