Skip to content

Commit

Permalink
[feat] Improve Reader API (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Sep 7, 2023
1 parent beb846c commit 7bb829d
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 13 deletions.
27 changes: 20 additions & 7 deletions fixcloudutils/redis/event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from fixcloudutils.asyncio import stop_running_task
from fixcloudutils.asyncio.periodic import Periodic
from fixcloudutils.service import Service
from fixcloudutils.util import utc_str
from fixcloudutils.util import utc_str, parse_utc_str, utc

log = logging.getLogger("fix.event_stream")
T = TypeVar("T")
Expand Down Expand Up @@ -82,6 +82,15 @@ async def with_backoff(self, fn: Callable[[], Awaitable[T]], attempt: int = 0) -
NoBackoff = Backoff(0, 0, 0)


@define(frozen=True, slots=True)
class MessageContext:
id: str
kind: str
publisher: str
sent_at: datetime
received_at: datetime


class RedisStreamListener(Service):
"""
Allows processing of messages from a redis stream in a group of readers.
Expand All @@ -99,7 +108,7 @@ def __init__(
stream: str,
group: str,
listener: str,
message_processor: Callable[[Json], Union[Awaitable[Any], Any]],
message_processor: Callable[[Json, MessageContext], Union[Awaitable[Any], Any]],
consider_failed_after: timedelta,
batch_size: int = 1000,
stop_on_fail: bool = False,
Expand Down Expand Up @@ -168,12 +177,16 @@ async def _handle_stream_messages(self, messages: List[Any]) -> None:
async def _handle_single_message(self, message: Json) -> None:
try:
if "id" in message and "at" in message and "data" in message:
mid = message["id"]
at = message["at"]
publisher = message["publisher"]
context = MessageContext(
id=message["id"],
kind=message["kind"],
publisher=message["publisher"],
sent_at=parse_utc_str(message["at"]),
received_at=utc(),
)
data = json.loads(message["data"])
log.debug(f"Received message {self.listener}: message {mid}, from {publisher}, at {at} data: {data}")
await self.backoff.with_backoff(partial(self.message_processor, data))
log.debug(f"Received message {self.listener}: message {context} data: {data}")
await self.backoff.with_backoff(partial(self.message_processor, data, context))
else:
log.warning(f"Invalid message format: {message}. Ignore.")
except Exception as e:
Expand Down
62 changes: 61 additions & 1 deletion fixcloudutils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime, timezone
from typing import Optional, TypeVar
from typing import Optional, TypeVar, Union, List, Any

from fixcloudutils.types import JsonElement, Json

T = TypeVar("T")
UTC_Date_Format = "%Y-%m-%dT%H:%M:%SZ"
Expand All @@ -35,3 +37,61 @@ def parse_utc_str(s: str) -> datetime:

def identity(o: T) -> T:
return o


def value_in_path_get(element: JsonElement, path_or_name: Union[List[str], str], if_none: T) -> T:
result = value_in_path(element, path_or_name)
return result if result is not None and isinstance(result, type(if_none)) else if_none


def value_in_path(element: JsonElement, path_or_name: Union[List[str], str]) -> Optional[Any]:
path = path_or_name if isinstance(path_or_name, list) else path_or_name.split(".")
at = len(path)

def at_idx(current: JsonElement, idx: int) -> Optional[Any]:
if at == idx:
return current
elif current is None or not isinstance(current, dict) or path[idx] not in current:
return None
else:
return at_idx(current[path[idx]], idx + 1)

return at_idx(element, 0)


def set_value_in_path(element: JsonElement, path_or_name: Union[List[str], str], js: Optional[Json] = None) -> Json:
path = path_or_name if isinstance(path_or_name, list) else path_or_name.split(".")
at = len(path) - 1

def at_idx(current: Json, idx: int) -> None:
if at == idx:
current[path[-1]] = element
else:
value = current.get(path[idx])
if not isinstance(value, dict):
value = {}
current[path[idx]] = value
at_idx(value, idx + 1)

js = js if js is not None else {}
at_idx(js, 0)
return js


def del_value_in_path(element: JsonElement, path_or_name: Union[List[str], str]) -> JsonElement:
path = path_or_name if isinstance(path_or_name, list) else path_or_name.split(".")
pl = len(path) - 1

def at_idx(current: JsonElement, idx: int) -> JsonElement:
if current is None or not isinstance(current, dict) or path[idx] not in current:
return element
elif pl == idx:
current.pop(path[-1], None)
return element
else:
result = at_idx(current[path[idx]], idx + 1)
if not current[path[idx]]:
current[path[idx]] = None
return result

return at_idx(element, 0)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fixcloudutils"
version = "1.3.0"
version = "1.4.0"
authors = [{ name = "Some Engineering Inc." }]
description = "Utilities for fixcloud."
license = { file = "LICENSE" }
Expand Down
9 changes: 5 additions & 4 deletions tests/event_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from contextlib import suppress
from datetime import timedelta
from functools import partial
from typing import Dict, Tuple, List
from typing import Dict, Tuple, List, Any

import pytest
from cattrs import unstructure, structure
Expand All @@ -31,6 +31,7 @@
RedisStreamPublisher,
Backoff,
RedisStreamListener,
MessageContext,
)
from fixcloudutils.types import Json

Expand All @@ -41,7 +42,7 @@ async def test_stream(redis: Redis) -> None:
group_counter: Dict[int, int] = defaultdict(int)
listener_counter: Dict[Tuple[int, int], int] = defaultdict(int)

async def handle_message(group: int, uid: int, message: Json) -> None:
async def handle_message(group: int, uid: int, message: Json, _: MessageContext) -> None:
# make sure we can read the message
data = structure(message, ExampleData)
assert data.bar == "foo"
Expand Down Expand Up @@ -141,7 +142,7 @@ async def check_all_arrived(num_messages: int) -> bool:
return True
await asyncio.sleep(0.1)

async def handle_message(message: Json) -> None:
async def handle_message(message: Json, context: Any) -> None:
arrived_messages.append(message)

listener = RedisStreamListener(redis, "test-stream", "foo", "bar", handle_message, timedelta(seconds=5))
Expand All @@ -160,7 +161,7 @@ async def handle_message(message: Json) -> None:
async def test_failure(redis: Redis) -> None:
counter = 0

async def handle_message(_: Json) -> None:
async def handle_message(message: Json, context: Any) -> None:
nonlocal counter
counter += 1
raise Exception("boom")
Expand Down

0 comments on commit 7bb829d

Please sign in to comment.