Skip to content

Commit

Permalink
feat(TestStreamClient): support custom classes for test consumers and…
Browse files Browse the repository at this point in the history
… producers
  • Loading branch information
JeroennC committed Jan 29, 2024
1 parent 2898233 commit 0ea7aba
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
20 changes: 12 additions & 8 deletions kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Type

from kstreams import ConsumerRecord
from kstreams import Consumer, ConsumerRecord, Producer
from kstreams.engine import StreamEngine
from kstreams.prometheus.monitor import PrometheusMonitor
from kstreams.serializers import Serializer
Expand Down Expand Up @@ -44,28 +44,32 @@ def __init__(
stream_engine: StreamEngine,
monitoring_enabled: bool = True,
topics: Optional[List[str]] = None,
test_producer_class: Type[Producer] = TestProducer,
test_consumer_class: Type[Consumer] = TestConsumer,
) -> None:
self.stream_engine = stream_engine
self.test_producer_class = test_producer_class
self.test_consumer_class = test_consumer_class

# Extra topics' names defined by the end user which must be created
# before the cycle test starts
self.extra_user_topics = topics

# store the user clients to restore them later
self.monitor = stream_engine.monitor
self.producer_class = self.stream_engine.producer_class
self.consumer_class = self.stream_engine.consumer_class
self.engine_producer_class = self.stream_engine.producer_class
self.engine_consumer_class = self.stream_engine.consumer_class

self.stream_engine.producer_class = TestProducer
self.stream_engine.consumer_class = TestConsumer
self.stream_engine.producer_class = self.test_producer_class
self.stream_engine.consumer_class = self.test_consumer_class

if not monitoring_enabled:
self.stream_engine.monitor = TestMonitor()

def mock_streams(self) -> None:
streams: List[Stream] = self.stream_engine._streams
for stream in streams:
stream.consumer_class = TestConsumer
stream.consumer_class = self.test_consumer_class

def setup_mocks(self) -> None:
self.mock_streams()
Expand All @@ -87,8 +91,8 @@ async def stop(self) -> None:
await self.stream_engine.stop()

# restore original config
self.stream_engine.producer_class = self.producer_class
self.stream_engine.consumer_class = self.consumer_class
self.stream_engine.producer_class = self.engine_producer_class
self.stream_engine.consumer_class = self.engine_consumer_class
self.stream_engine.monitor = self.monitor

# clean the topics after finishing the test to make sure that
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ async def test_engine_clients(stream_engine: StreamEngine):
assert stream_engine.producer_class is TestProducer

# after leaving the context, everything should go to normal
assert client.stream_engine.consumer_class is client.consumer_class
assert client.stream_engine.producer_class is client.producer_class
assert client.stream_engine.consumer_class is client.engine_consumer_class
assert client.stream_engine.producer_class is client.engine_producer_class


@pytest.mark.asyncio
Expand Down

0 comments on commit 0ea7aba

Please sign in to comment.