Skip to content

Commit

Permalink
fix: add commit and commited functions to consumer test client. Closes
Browse files Browse the repository at this point in the history
…#61 (#63)

Co-authored-by: Marcos Schroh <[email protected]>
  • Loading branch information
marcosschroh and marcosschroh authored Oct 5, 2022
1 parent 43faad1 commit 66f8ee5
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 2 deletions.
42 changes: 42 additions & 0 deletions docs/test_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,48 @@ async def test_add_event_on_consume():
The `TestStreamClient.send` coroutine is used instead.
This allows to test `streams` without having producer code in your application

### Testing the Commit

In some cases your stream will commit, in this situation checking the commited partitions can be useful.

```python
import pytest
from kstreams.test_utils import TestStreamClient

from .example import produce, stream_engine

topic_name = "local--kstreams-marcos"
value = b'{"message": "Hello world!"}'
name = "my-stream"
key = "1"
partition = 2
tp = structs.TopicPartition(
topic=topic_name,
partition=partition,
)
total_events = 10

@stream_engine.stream(topic_name, name=name)
async def my_stream(stream: Stream):
async for cr in stream:
# commit every time that an event arrives
await stream.consumer.commit({tp: cr.offset})


# test the code
client = TestStreamClient(stream_engine)

@pytest.mark.asyncio
async def test_consumer_commit(stream_engine: StreamEngine):
async with client:
for _ in range(0, total_events):
await client.send(topic_name, partition=partition, value=value, key=key)

# check that everything was commited
stream = stream_engine.get_stream(name)
assert (await stream.consumer.committed(tp)) == total_events
```

### E2E test

In the previous code example the application produces to and consumes from the same topic, then `TestStreamClient.send` is not needed because the `engine.send` is producing. For those situation you can just use your `producer` code and check that certain code was called.
Expand Down
2 changes: 1 addition & 1 deletion kstreams/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def send(
topic: str,
value: Any = None,
key: Any = None,
partition: Optional[str] = None,
partition: Optional[int] = None,
timestamp_ms: Optional[int] = None,
headers: Optional[Headers] = None,
serializer: Optional[Serializer] = None,
Expand Down
3 changes: 3 additions & 0 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ async def stop(self) -> None:
self._consumer_task.cancel()

async def start(self) -> Optional[AsyncGenerator]:
if self.running:
return None

async def func_wrapper(func):
try:
# await for the end user coroutine
Expand Down
10 changes: 10 additions & 0 deletions kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, *topics: str, group_id: Optional[str] = None, **kwargs) -> No
self.topics: Tuple[str, ...] = topics
self._group_id: Optional[str] = group_id
self._assigments: List[TopicPartition] = []
self.partitions_committed: Dict[TopicPartition, int] = {}

for topic_name in topics:
TopicManager.create(topic_name, consumer=self)
Expand All @@ -90,6 +91,15 @@ async def position(self, topic_partition: TopicPartition) -> int:
def highwater(self, topic_partition: TopicPartition) -> int:
return self.last_stable_offset(topic_partition)

async def commit(self, offsets: Optional[Dict[TopicPartition, int]] = None) -> None:
if offsets is not None:
for topic_partition, offset in offsets.items():
self.partitions_committed[topic_partition] = offset
return None

async def committed(self, topic_partition: TopicPartition) -> Optional[int]:
return self.partitions_committed.get(topic_partition)

async def getone(
self,
) -> Optional[ConsumerRecord]: # The return type must be fixed later on
Expand Down
2 changes: 1 addition & 1 deletion kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def send(
topic: str,
value: Any = None,
key: Optional[Any] = None,
partition: Optional[str] = None,
partition: Optional[int] = None,
timestamp_ms: Optional[int] = None,
headers: Optional[Headers] = None,
serializer: Optional[Serializer] = None,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from kstreams import StreamEngine
from kstreams.streams import Stream
from kstreams.test_utils import (
TestConsumer,
TestProducer,
Expand Down Expand Up @@ -87,6 +88,34 @@ async def test_topic_created(stream_engine: StreamEngine):
assert consumer_record.key == key


@pytest.mark.asyncio
async def test_consumer_commit(stream_engine: StreamEngine):
topic_name = "local--kstreams-marcos"
value = b'{"message": "Hello world!"}'
name = "my-stream"
key = "1"
partition = 2
tp = structs.TopicPartition(
topic=topic_name,
partition=partition,
)
total_events = 10

@stream_engine.stream(topic_name, name=name)
async def my_stream(stream: Stream):
async for cr in stream:
await stream.consumer.commit({tp: cr.offset})

client = TestStreamClient(stream_engine)
async with client:
for _ in range(0, total_events):
await client.send(topic_name, partition=partition, value=value, key=key)

# check that everything was commited
stream = stream_engine.get_stream(name)
assert (await stream.consumer.committed(tp)) == total_events


@pytest.mark.asyncio
async def test_e2e_example():
"""
Expand Down

0 comments on commit 66f8ee5

Please sign in to comment.