Skip to content

Commit

Permalink
break(AsyncGenerator): Stream as async generators (yield from stream)…
Browse files Browse the repository at this point in the history
… are only available with the new typing approach (#157)
  • Loading branch information
marcosschroh authored Jan 11, 2024
1 parent dcf20f5 commit 6c0cbf9
Show file tree
Hide file tree
Showing 12 changed files with 390 additions and 266 deletions.
8 changes: 5 additions & 3 deletions docs/stream.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ To facilitate the process, we have `context manager` that makes sure of the `sta
```python title="Yield example"
# Create your stream
@stream_engine.stream("local--kstream")
async def stream(stream: Stream):
async for cr in stream:
yield cr.value
async def stream(cr: ConsumerRecord, stream: Stream):
yield cr.value


# Consume the stream:
Expand All @@ -258,6 +257,9 @@ async with stream as stream_flow: # Use the context manager
If for some reason you interrupt the "async for in" in the async generator, the Stream will stopped consuming events
meaning that the lag will increase.

!!! note
Yield from a stream only works with the [typing approach](https://kpn.github.io/kstreams/stream/#dependency-injection-and-typing)

## Get many

::: kstreams.streams.Stream.getmany
Expand Down
2 changes: 1 addition & 1 deletion examples/fastapi-sse/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ poetry install
## Usage

1. Start the kafka cluster: From `kstreams` project root execute `./scripts/cluster/start`
2. Start the `FastAPI webserver`. Inside the `fastapi-sse` folder execute `poetry run python -m fastapi_sse`
2. Start the `FastAPI webserver`. Inside the `fastapi-sse` folder execute `poetry run app`
3. Consume events from the topic with `fastapi-sse`: `curl http://localhost:8000/topics/local--sse/group-1/`. If everything worked, you should see a log similar to the following one where the `webserever` is running:
```bash
INFO: 127.0.0.1:51060 - "GET /topics/local--sse/group-1/ HTTP/1.1" 200 OK
Expand Down
7 changes: 6 additions & 1 deletion examples/fastapi-sse/fastapi_sse/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import uvicorn

if __name__ == "__main__":

def main():
uvicorn.run(
app="fastapi_sse.app:app",
host="localhost",
Expand All @@ -11,3 +12,7 @@
reload=True,
debug=True,
)


if __name__ == "__main__":
main()
11 changes: 6 additions & 5 deletions examples/fastapi-sse/fastapi_sse/streaming/streams.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from kstreams import Stream
from typing import Optional

from kstreams import ConsumerRecord, Stream

from .engine import stream_engine


def stream_factory(
*, topic: str, group_id: str = None, auto_offset_reset: str = "latest"
*, topic: str, group_id: Optional[str] = None, auto_offset_reset: str = "latest"
):
async def stream_func(stream: Stream):
async for cr in stream:
yield cr.value
async def stream_func(cr: ConsumerRecord):
yield cr.value

s = Stream(
topic,
Expand Down
402 changes: 236 additions & 166 deletions examples/fastapi-sse/poetry.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions examples/fastapi-sse/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ authors = ["Marcos Schroh <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
sse-starlette = "^0.10.3"
sse-starlette = "^1.8.2"
kstreams = { path = "../../.", develop = true }
uvicorn = "^0.18.2"
fastapi = "^0.78.0"
fastapi = "^0.108.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
app = "fastapi_sse.__main__:main"
63 changes: 31 additions & 32 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import update_wrapper
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Expand All @@ -30,8 +29,7 @@

logger = logging.getLogger(__name__)

# Function required by the `stream` decorator
StreamFunc = Callable[..., Awaitable[Any]]
StreamFunc = Callable


class Stream:
Expand Down Expand Up @@ -112,6 +110,7 @@ def __init__(
self.initial_offsets = initial_offsets
self.seeked_initial_offsets = False
self.rebalance_listener = rebalance_listener
self.udf_type = inspect_udf(func, Stream)

# aiokafka expects topic names as arguments, meaning that
# can receive N topics -> N arguments,
Expand Down Expand Up @@ -203,30 +202,29 @@ async def stream(stream: Stream):
*partitions, timeout_ms=timeout_ms, max_records=max_records
)

async def start(self) -> Optional[AsyncGenerator]:
async def start(self) -> None:
if self.running:
return None

await self._subscribe()
udf_type = inspect_udf(self.func, Stream)

if udf_type == UDFType.NO_TYPING:
if self.udf_type == UDFType.NO_TYPING:
# normal use case
logging.warn(
"Streams with `async for in` loop approach might be deprecated. "
"Consider migrating to a typing approach."
)

func = self.func(self)
if inspect.isasyncgen(func):
return func
else:
# It is not an async_generator so we need to
# create an asyncio.Task with func
logging.warning(
"Streams with `async for in` loop approach might be deprecated. "
"Consider migrating to a typing approach."
)
self._consumer_task = asyncio.create_task(self.func_wrapper(func))
# create an asyncio.Task with func
self._consumer_task = asyncio.create_task(self.func_wrapper(func))
else:
self._consumer_task = asyncio.create_task(
self.func_wrapper_with_typing(udf_type)
)
# Typing cases
if not inspect.isasyncgenfunction(self.func):
# Is not an async_generator, then create an asyncio.Task with func
self._consumer_task = asyncio.create_task(
self.func_wrapper_with_typing()
)
return None

async def func_wrapper(self, func: Awaitable) -> None:
Expand All @@ -238,14 +236,14 @@ async def func_wrapper(self, func: Awaitable) -> None:
except Exception as e:
logger.exception(f"CRASHED Stream!!! Task {self._consumer_task} \n\n {e}")

async def func_wrapper_with_typing(self, calling_type: UDFType) -> None:
async def func_wrapper_with_typing(self) -> None:
try:
# await for the end user coroutine
# we do this to show a better error message to the user
# when the coroutine fails
while True:
cr = await self.getone()
if calling_type == UDFType.CR_ONLY_TYPING:
if self.udf_type == UDFType.CR_ONLY_TYPING:
await self.func(cr)
else:
# typing with cr and stream
Expand Down Expand Up @@ -279,16 +277,15 @@ def seek_to_initial_offsets(self):
)
self.seeked_initial_offsets = True

async def __aenter__(self) -> AsyncGenerator:
async def __aenter__(self) -> "Stream":
"""
Start the kafka Consumer and return an `async_gen` so it can be iterated
!!! Example
```python title="Usage"
@stream_engine.stream(topic, group_id=group_id, ...)
async def stream(stream):
async for cr in stream:
yield cr.value
yield cr.value
# Iterate the stream:
Expand All @@ -298,9 +295,8 @@ async def stream(stream):
```
"""
logger.info("Starting async_gen Stream....")
async_gen = await self.start()
# For now ignoring the typing issue. The start method might be splited
return async_gen # type: ignore
await self.start()
return self

async def __aexit__(self, exc_type, exc, tb) -> None:
logger.info("Stopping async_gen Stream....")
Expand All @@ -310,12 +306,15 @@ def __aiter__(self):
return self

async def __anext__(self) -> ConsumerRecord:
# This will be used only with async generators
if not self.running:
await self.start()

try:
return await self.getone()
cr = await self.getone()

if self.udf_type == UDFType.NO_TYPING:
return cr
elif self.udf_type == UDFType.CR_ONLY_TYPING:
return await anext(self.func(cr))
else:
return await anext(self.func(cr, self))
except errors.ConsumerStoppedError:
raise StopAsyncIteration # noqa: F821

Expand Down
2 changes: 1 addition & 1 deletion kstreams/streams_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class UDFType(str, enum.Enum):
ALL_TYPING = "ALL_TYPING"


def inspect_udf(func: Callable[..., Any], a_type: Any) -> UDFType:
def inspect_udf(func: Callable, a_type: Any) -> UDFType:
"""
Inspect the user defined function (coroutine) to get the proper way to call it
Expand Down
6 changes: 5 additions & 1 deletion kstreams/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Dict, Sequence, Tuple
from typing import (
Dict,
Sequence,
Tuple,
)

Headers = Dict[str, str]
EncodedHeaders = Sequence[Tuple[str, bytes]]
94 changes: 94 additions & 0 deletions tests/test_async_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Callable
from unittest import mock

import pytest

from kstreams import ConsumerRecord, Stream, StreamEngine, TestStreamClient
from kstreams.clients import Consumer


@pytest.mark.asyncio
async def test_add_stream_as_generator(
stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord]
):
@stream_engine.stream("local--hello-kpn")
async def stream(cr: ConsumerRecord):
yield cr

assert stream == stream_engine._streams[0]
assert not stream.running

cr = consumer_record_factory()

async def getone(_):
return cr

with mock.patch.multiple(Consumer, start=mock.DEFAULT, getone=getone):
# simulate an engine start
await stream.start()

# Now the stream should be running as we are in the context
assert stream.running
async for value in stream:
assert value == cr
break


@pytest.mark.asyncio
async def test_stream_consume_events_as_generator_cr_typing(
stream_engine: StreamEngine,
):
topic = "local--hello-kpn"
event = b'{"message": "Hello world!"}'
client = TestStreamClient(stream_engine)
save_to_db = mock.Mock()

@stream_engine.stream(topic)
async def stream(cr: ConsumerRecord):
save_to_db(cr.value)
yield cr

async with client:
await client.send(topic, value=event, key="1")

async with stream as stream_flow:
async for cr in stream_flow:
assert cr.value == event
break

# we left the stream context, so it has stopped
assert not stream.running

# check that the event was consumed
save_to_db.assert_called_once_with(event)


@pytest.mark.asyncio
async def test_stream_consume_events_as_generator_all_typing(
stream_engine: StreamEngine,
):
topic = "local--hello-kpn"
event = b'{"message": "Hello world!"}'
client = TestStreamClient(stream_engine)
save_to_db = mock.Mock()

@stream_engine.stream(topic)
async def my_stream(cr: ConsumerRecord, stream: Stream):
save_to_db(cr.value)
# commit the event. Not ideal but we want to prove that works
await stream.commit()
yield cr

async with client:
await client.send(topic, value=event, key="1")

async with my_stream as stream_flow:
async for cr in stream_flow:
assert cr.value == event
break

# we left the stream context, so it has stopped
assert not my_stream.running

# check that the event was consumed
save_to_db.assert_called_once_with(event)
25 changes: 0 additions & 25 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,6 @@ async def stream(stream: Stream):
save_to_db.assert_called_once_with([event for _ in range(0, max_records)])


@pytest.mark.asyncio
async def test_stream_consume_events_as_generator(stream_engine: StreamEngine):
topic = "local--hello-kpn"
event = b'{"message": "Hello world!"}'
client = TestStreamClient(stream_engine)
save_to_db = Mock()

@stream_engine.stream(topic)
async def my_stream(stream: Stream):
async for cr in stream:
save_to_db(cr.value)
yield cr

async with client:
await client.send(topic, value=event, key="1")

async with my_stream as processor:
async for cr in processor:
assert cr.value == event
break

# check that the event was consumed
save_to_db.assert_called_once_with(event)


@pytest.mark.asyncio
async def test_stream_func_with_cr(stream_engine: StreamEngine):
client = TestStreamClient(stream_engine)
Expand Down
Loading

0 comments on commit 6c0cbf9

Please sign in to comment.