From 9117c4b69f83f0fc11dda1f529e2515620d6d090 Mon Sep 17 00:00:00 2001 From: Marcos Schroh <2828842+marcosschroh@users.noreply.github.com> Date: Tue, 11 Oct 2022 08:57:30 +0200 Subject: [PATCH] fix: TestConsumer partition assigments, TestProducer consumer record and record metadata (#66) Co-authored-by: Marcos Schroh --- examples/simple.py | 2 +- kstreams/test_utils/test_clients.py | 41 ++++++++++++++++++++++------- kstreams/test_utils/topics.py | 34 +++++++++++++++++++----- tests/test_client.py | 23 ++++++++++------ 4 files changed, 74 insertions(+), 26 deletions(-) diff --git a/examples/simple.py b/examples/simple.py index 514064f9..0f23476d 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -5,7 +5,7 @@ from kstreams import ConsumerRecord, create_engine from kstreams.streams import Stream -topic = "local--kstreams" +topic = "local--kstreams-test" stream_engine = create_engine(title="my-stream-engine") diff --git a/kstreams/test_utils/test_clients.py b/kstreams/test_utils/test_clients.py index 992bb924..54852abe 100644 --- a/kstreams/test_utils/test_clients.py +++ b/kstreams/test_utils/test_clients.py @@ -30,7 +30,9 @@ async def send( ) -> Coroutine: topic = TopicManager.get_or_create(topic_name) timestamp_ms = timestamp_ms or datetime.now().timestamp() - total_messages = topic.total_messages + 1 + total_partition_events = ( + topic.get_total_partition_events(partition=partition) + 1 + ) consumer_record = ConsumerRecord( topic=topic_name, @@ -39,7 +41,7 @@ async def send( headers=headers, partition=partition, timestamp=timestamp_ms, - offset=total_messages, + offset=total_partition_events, timestamp_type=None, checksum=None, serialized_key_size=None, @@ -53,7 +55,7 @@ async def fut(): topic=topic_name, partition=partition, timestamp=timestamp_ms, - offset=total_messages, + offset=total_partition_events, ) return fut() @@ -64,25 +66,42 @@ def __init__(self, *topics: str, group_id: Optional[str] = None, **kwargs) -> No # copy the aiokafka behavior self.topics: Tuple[str, ...] = topics self._group_id: Optional[str] = group_id - self._assigments: List[TopicPartition] = [] + self._assignment: List[TopicPartition] = [] self.partitions_committed: Dict[TopicPartition, int] = {} for topic_name in topics: - TopicManager.create(topic_name, consumer=self) - self._assigments.append(TopicPartition(topic=topic_name, partition=1)) + TopicManager.get_or_create(topic_name, consumer=self) + self._assignment.append(TopicPartition(topic=topic_name, partition=1)) # Called to make sure that has all the kafka attributes like _coordinator # so it will behave like an real Kafka Consumer super().__init__() def assignment(self) -> List[TopicPartition]: - return self._assigments + return self._assignment + + def _check_partition_assignments(self, consumer_record: ConsumerRecord) -> None: + """ + When an event is consumed the partition can be any positive int number + because there is not limit in the producer side (only during testing of course). + In case that the partition is not in the `_assignment` we need to register it. + + This is only during testing as in real use cases the assignments happens + at the moment of kafka bootstrapping + """ + topic_partition = TopicPartition( + topic=consumer_record.topic, + partition=consumer_record.partition, + ) + + if topic_partition not in self._assignment: + self._assignment.append(topic_partition) def last_stable_offset(self, topic_partition: TopicPartition) -> int: topic = TopicManager.get(topic_partition.topic) if topic is not None: - return topic.total_messages + return topic.get_total_partition_events(partition=topic_partition.partition) return -1 async def position(self, topic_partition: TopicPartition) -> int: @@ -104,12 +123,14 @@ async def getone( self, ) -> Optional[ConsumerRecord]: # The return type must be fixed later on topic = None - for topic_partition in self._assigments: + for topic_partition in self._assignment: topic = TopicManager.get(topic_partition.topic) if not topic.consumed: break if topic is not None: - return await topic.get() + consumer_record = await topic.get() + self._check_partition_assignments(consumer_record) + return consumer_record return None diff --git a/kstreams/test_utils/topics.py b/kstreams/test_utils/topics.py index 32678cd9..180c79af 100644 --- a/kstreams/test_utils/topics.py +++ b/kstreams/test_utils/topics.py @@ -1,6 +1,7 @@ import asyncio -from dataclasses import dataclass -from typing import ClassVar, Dict, Optional +from collections import defaultdict +from dataclasses import dataclass, field +from typing import ClassVar, DefaultDict, Dict, Optional from kstreams import ConsumerRecord @@ -11,13 +12,19 @@ class Topic: name: str queue: asyncio.Queue - total_messages: int = 0 + total_partition_events: DefaultDict[int, int] = field( + default_factory=lambda: defaultdict(int) + ) + total_events: int = 0 # for now we assumed that 1 streams is connected to 1 topic consumer: Optional["test_clients.Consumer"] = None async def put(self, event: ConsumerRecord) -> None: await self.queue.put(event) - self.total_messages += 1 + + # keep track of the amount of events per topic partition + self.total_partition_events[event.partition] += 1 + self.total_events += 1 async def get(self) -> ConsumerRecord: return await self.queue.get() @@ -25,6 +32,12 @@ async def get(self) -> ConsumerRecord: def is_empty(self) -> bool: return self.queue.empty() + def size(self) -> int: + return self.queue.qsize() + + def get_total_partition_events(self, *, partition: int) -> int: + return self.total_partition_events[partition] + @property def consumed(self) -> bool: """ @@ -54,7 +67,12 @@ def get(cls, name: str) -> Topic: if topic is not None: return topic - raise ValueError(f"Topic {name} not found") + raise ValueError( + f"You might be trying to get the topic {name} outside the " + "`client async context` or trying to geh an event from an empty " + f"topic {name}. Make sure that the code is inside the async context" + "and the topic has events." + ) @classmethod def create( @@ -65,14 +83,16 @@ def create( return topic @classmethod - def get_or_create(cls, name: str) -> Topic: + def get_or_create( + cls, name: str, consumer: Optional["test_clients.Consumer"] = None + ) -> Topic: """ Add a new queue if does not exist and return it """ try: topic = cls.get(name) except ValueError: - topic = cls.create(name) + topic = cls.create(name, consumer=consumer) return topic @classmethod diff --git a/tests/test_client.py b/tests/test_client.py index f0c559a3..355a98de 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -35,23 +35,30 @@ async def test_send_event_with_test_client(stream_engine: StreamEngine): metadata = await client.send( topic, value=b'{"message": "Hello world!"}', key="1" ) - current_offset = metadata.offset + assert metadata.topic == topic assert metadata.partition == 1 + assert metadata.offset == 1 # send another event and check that the offset was incremented metadata = await client.send( topic, value=b'{"message": "Hello world!"}', key="1" ) - assert metadata.offset == current_offset + 1 + assert metadata.offset == 2 + + # send en event to a different partition + metadata = await client.send( + topic, value=b'{"message": "Hello world!"}', key="1", partition=2 + ) + + # because it is a different partition the offset should be 1 + assert metadata.offset == 1 @pytest.mark.asyncio async def test_streams_consume_events(stream_engine: StreamEngine): - from examples.simple import stream_engine - client = TestStreamClient(stream_engine) - topic = "local--kstreams-2" + topic = "local--kstreams-consumer" event = b'{"message": "Hello world!"}' tp = structs.TopicPartition(topic=topic, partition=1) save_to_db = Mock() @@ -91,7 +98,7 @@ async def test_topic_created(stream_engine: StreamEngine): @pytest.mark.asyncio async def test_consumer_commit(stream_engine: StreamEngine): - topic_name = "local--kstreams-marcos" + topic_name = "local--kstreams-consumer-commit" value = b'{"message": "Hello world!"}' name = "my-stream" key = "1" @@ -151,7 +158,7 @@ async def test_e2e_consume_multiple_topics(): topic_1 = TopicManager.get(topics[0]) topic_2 = TopicManager.get(topics[1]) - assert topic_1.total_messages == events_per_topic - assert topic_2.total_messages == events_per_topic + assert topic_1.total_events == events_per_topic + assert topic_2.total_events == events_per_topic assert TopicManager.all_messages_consumed()