diff --git a/kstreams/test_utils/test_clients.py b/kstreams/test_utils/test_clients.py index 112faa69..bcfe6d82 100644 --- a/kstreams/test_utils/test_clients.py +++ b/kstreams/test_utils/test_clients.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Coroutine, Dict, List, Optional, Tuple +from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple from aiokafka.structs import ConsumerRecord @@ -119,6 +119,17 @@ async def commit(self, offsets: Optional[Dict[TopicPartition, int]] = None) -> N async def committed(self, topic_partition: TopicPartition) -> Optional[int]: return self.partitions_committed.get(topic_partition) + def partitions_for_topic(self, topic: str) -> Set: + """ + Return the partitions of all assigned topics. The `topic` argument is not used + because in a testing enviroment the only topics are the ones declared by the end + user. + + The AIOKafkaConsumer returns a Set, so we do the same. + """ + partitions = [topic_partition.partition for topic_partition in self._assignment] + return set(partitions) + async def getone( self, ) -> Optional[ConsumerRecord]: # The return type must be fixed later on diff --git a/tests/test_client.py b/tests/test_client.py index 49de2abf..1ca0054a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -145,6 +145,28 @@ async def test_clean_up_events(stream_engine: StreamEngine): assert not TopicManager.topics +@pytest.mark.asyncio +async def test_partitions_for_topic(stream_engine: StreamEngine): + topic_name = "local--kstreams" + value = b'{"message": "Hello world!"}' + key = "1" + client = TestStreamClient(stream_engine) + + @stream_engine.stream(topic_name, name="my-stream") + async def consume(stream): + async for cr in stream: + ... + + async with client: + # produce to events and consume only one in the client context + await client.send(topic_name, value=value, key=key) + await client.send(topic_name, value=value, key=key, partition=2) + await client.send(topic_name, value=value, key=key, partition=10) + + stream = stream_engine.get_stream("my-stream") + assert stream.consumer.partitions_for_topic(topic_name) == set([0, 2, 10]) + + @pytest.mark.asyncio async def test_consumer_commit(stream_engine: StreamEngine): topic_name = "local--kstreams-consumer-commit"