Skip to content

Commit

Permalink
fix: TestConsumer partition assigments, TestProducer consumer record …
Browse files Browse the repository at this point in the history
…and record metadata (#66)

Co-authored-by: Marcos Schroh <[email protected]>
  • Loading branch information
marcosschroh and marcosschroh authored Oct 11, 2022
1 parent 06b587f commit 9117c4b
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from kstreams import ConsumerRecord, create_engine
from kstreams.streams import Stream

topic = "local--kstreams"
topic = "local--kstreams-test"

stream_engine = create_engine(title="my-stream-engine")

Expand Down
41 changes: 31 additions & 10 deletions kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ async def send(
) -> Coroutine:
topic = TopicManager.get_or_create(topic_name)
timestamp_ms = timestamp_ms or datetime.now().timestamp()
total_messages = topic.total_messages + 1
total_partition_events = (
topic.get_total_partition_events(partition=partition) + 1
)

consumer_record = ConsumerRecord(
topic=topic_name,
Expand All @@ -39,7 +41,7 @@ async def send(
headers=headers,
partition=partition,
timestamp=timestamp_ms,
offset=total_messages,
offset=total_partition_events,
timestamp_type=None,
checksum=None,
serialized_key_size=None,
Expand All @@ -53,7 +55,7 @@ async def fut():
topic=topic_name,
partition=partition,
timestamp=timestamp_ms,
offset=total_messages,
offset=total_partition_events,
)

return fut()
Expand All @@ -64,25 +66,42 @@ def __init__(self, *topics: str, group_id: Optional[str] = None, **kwargs) -> No
# copy the aiokafka behavior
self.topics: Tuple[str, ...] = topics
self._group_id: Optional[str] = group_id
self._assigments: List[TopicPartition] = []
self._assignment: List[TopicPartition] = []
self.partitions_committed: Dict[TopicPartition, int] = {}

for topic_name in topics:
TopicManager.create(topic_name, consumer=self)
self._assigments.append(TopicPartition(topic=topic_name, partition=1))
TopicManager.get_or_create(topic_name, consumer=self)
self._assignment.append(TopicPartition(topic=topic_name, partition=1))

        # Called to make sure that it has all the Kafka attributes, like _coordinator,
        # so it will behave like a real Kafka Consumer
super().__init__()

def assignment(self) -> List[TopicPartition]:
return self._assigments
return self._assignment

def _check_partition_assignments(self, consumer_record: ConsumerRecord) -> None:
"""
        When an event is consumed, the partition can be any positive int number
        because there is no limit on the producer side (only during testing, of course).
        If the partition is not in `_assignment`, we need to register it.
        This happens only during testing; in real use cases the assignment happens
        at the moment of Kafka bootstrapping
"""
topic_partition = TopicPartition(
topic=consumer_record.topic,
partition=consumer_record.partition,
)

if topic_partition not in self._assignment:
self._assignment.append(topic_partition)

def last_stable_offset(self, topic_partition: TopicPartition) -> int:
topic = TopicManager.get(topic_partition.topic)

if topic is not None:
return topic.total_messages
return topic.get_total_partition_events(partition=topic_partition.partition)
return -1

async def position(self, topic_partition: TopicPartition) -> int:
Expand All @@ -104,12 +123,14 @@ async def getone(
self,
) -> Optional[ConsumerRecord]: # The return type must be fixed later on
topic = None
for topic_partition in self._assigments:
for topic_partition in self._assignment:
topic = TopicManager.get(topic_partition.topic)

if not topic.consumed:
break

if topic is not None:
return await topic.get()
consumer_record = await topic.get()
self._check_partition_assignments(consumer_record)
return consumer_record
return None
34 changes: 27 additions & 7 deletions kstreams/test_utils/topics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional
from collections import defaultdict
from dataclasses import dataclass, field
from typing import ClassVar, DefaultDict, Dict, Optional

from kstreams import ConsumerRecord

Expand All @@ -11,20 +12,32 @@
class Topic:
name: str
queue: asyncio.Queue
total_messages: int = 0
total_partition_events: DefaultDict[int, int] = field(
default_factory=lambda: defaultdict(int)
)
total_events: int = 0
    # for now we assume that 1 stream is connected to 1 topic
consumer: Optional["test_clients.Consumer"] = None

async def put(self, event: ConsumerRecord) -> None:
await self.queue.put(event)
self.total_messages += 1

        # keep track of the number of events per topic partition
self.total_partition_events[event.partition] += 1
self.total_events += 1

async def get(self) -> ConsumerRecord:
return await self.queue.get()

def is_empty(self) -> bool:
return self.queue.empty()

def size(self) -> int:
return self.queue.qsize()

def get_total_partition_events(self, *, partition: int) -> int:
return self.total_partition_events[partition]

@property
def consumed(self) -> bool:
"""
Expand Down Expand Up @@ -54,7 +67,12 @@ def get(cls, name: str) -> Topic:

if topic is not None:
return topic
raise ValueError(f"Topic {name} not found")
raise ValueError(
f"You might be trying to get the topic {name} outside the "
"`client async context` or trying to geh an event from an empty "
f"topic {name}. Make sure that the code is inside the async context"
"and the topic has events."
)

@classmethod
def create(
Expand All @@ -65,14 +83,16 @@ def create(
return topic

@classmethod
def get_or_create(cls, name: str) -> Topic:
def get_or_create(
cls, name: str, consumer: Optional["test_clients.Consumer"] = None
) -> Topic:
"""
Add a new queue if does not exist and return it
"""
try:
topic = cls.get(name)
except ValueError:
topic = cls.create(name)
topic = cls.create(name, consumer=consumer)
return topic

@classmethod
Expand Down
23 changes: 15 additions & 8 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,30 @@ async def test_send_event_with_test_client(stream_engine: StreamEngine):
metadata = await client.send(
topic, value=b'{"message": "Hello world!"}', key="1"
)
current_offset = metadata.offset

assert metadata.topic == topic
assert metadata.partition == 1
assert metadata.offset == 1

# send another event and check that the offset was incremented
metadata = await client.send(
topic, value=b'{"message": "Hello world!"}', key="1"
)
assert metadata.offset == current_offset + 1
assert metadata.offset == 2

        # send an event to a different partition
metadata = await client.send(
topic, value=b'{"message": "Hello world!"}', key="1", partition=2
)

# because it is a different partition the offset should be 1
assert metadata.offset == 1


@pytest.mark.asyncio
async def test_streams_consume_events(stream_engine: StreamEngine):
from examples.simple import stream_engine

client = TestStreamClient(stream_engine)
topic = "local--kstreams-2"
topic = "local--kstreams-consumer"
event = b'{"message": "Hello world!"}'
tp = structs.TopicPartition(topic=topic, partition=1)
save_to_db = Mock()
Expand Down Expand Up @@ -91,7 +98,7 @@ async def test_topic_created(stream_engine: StreamEngine):

@pytest.mark.asyncio
async def test_consumer_commit(stream_engine: StreamEngine):
topic_name = "local--kstreams-marcos"
topic_name = "local--kstreams-consumer-commit"
value = b'{"message": "Hello world!"}'
name = "my-stream"
key = "1"
Expand Down Expand Up @@ -151,7 +158,7 @@ async def test_e2e_consume_multiple_topics():
topic_1 = TopicManager.get(topics[0])
topic_2 = TopicManager.get(topics[1])

assert topic_1.total_messages == events_per_topic
assert topic_2.total_messages == events_per_topic
assert topic_1.total_events == events_per_topic
assert topic_2.total_events == events_per_topic

assert TopicManager.all_messages_consumed()

0 comments on commit 9117c4b

Please sign in to comment.