diff --git a/kstreams/test_utils/test_utils.py b/kstreams/test_utils/test_utils.py index f5fde440..9059514c 100644 --- a/kstreams/test_utils/test_utils.py +++ b/kstreams/test_utils/test_utils.py @@ -1,7 +1,7 @@ from types import TracebackType from typing import Any, Dict, List, Optional, Type -from kstreams import ConsumerRecord +from kstreams import Consumer, ConsumerRecord, Producer from kstreams.engine import StreamEngine from kstreams.prometheus.monitor import PrometheusMonitor from kstreams.serializers import Serializer @@ -44,8 +44,12 @@ def __init__( stream_engine: StreamEngine, monitoring_enabled: bool = True, topics: Optional[List[str]] = None, + test_producer_class: Type[Producer] = TestProducer, + test_consumer_class: Type[Consumer] = TestConsumer, ) -> None: self.stream_engine = stream_engine + self.test_producer_class = test_producer_class + self.test_consumer_class = test_consumer_class # Extra topics' names defined by the end user which must be created # before the cycle test starts @@ -53,11 +57,11 @@ def __init__( # store the user clients to restore them later self.monitor = stream_engine.monitor - self.producer_class = self.stream_engine.producer_class - self.consumer_class = self.stream_engine.consumer_class + self.engine_producer_class = self.stream_engine.producer_class + self.engine_consumer_class = self.stream_engine.consumer_class - self.stream_engine.producer_class = TestProducer - self.stream_engine.consumer_class = TestConsumer + self.stream_engine.producer_class = self.test_producer_class + self.stream_engine.consumer_class = self.test_consumer_class if not monitoring_enabled: self.stream_engine.monitor = TestMonitor() @@ -65,7 +69,7 @@ def __init__( def mock_streams(self) -> None: streams: List[Stream] = self.stream_engine._streams for stream in streams: - stream.consumer_class = TestConsumer + stream.consumer_class = self.test_consumer_class def setup_mocks(self) -> None: self.mock_streams() @@ -87,8 +91,8 @@ async def stop(self) -> None: await self.stream_engine.stop() # restore original config - self.stream_engine.producer_class = self.producer_class - self.stream_engine.consumer_class = self.consumer_class + self.stream_engine.producer_class = self.engine_producer_class + self.stream_engine.consumer_class = self.engine_consumer_class self.stream_engine.monitor = self.monitor # clean the topics after finishing the test to make sure that diff --git a/tests/test_client.py b/tests/test_client.py index e49c3363..c33f3d1f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -29,8 +29,8 @@ async def test_engine_clients(stream_engine: StreamEngine): assert stream_engine.producer_class is TestProducer # after leaving the context, everything should go to normal - assert client.stream_engine.consumer_class is client.consumer_class - assert client.stream_engine.producer_class is client.producer_class + assert client.stream_engine.consumer_class is client.engine_consumer_class + assert client.stream_engine.producer_class is client.engine_producer_class @pytest.mark.asyncio