diff --git a/kstreams/streams_utils.py b/kstreams/streams_utils.py index d1c55133..ad36f226 100644 --- a/kstreams/streams_utils.py +++ b/kstreams/streams_utils.py @@ -60,7 +60,7 @@ async def consume(cr: ConsumerRecord, stream: Stream, send: Send): first_annotation = params[0].annotation - if first_annotation in (inspect._empty, Stream): + if first_annotation in (inspect._empty, Stream) and len(params) < 2: # use case 1 NO_TYPING return UDFType.NO_TYPING # typing cases diff --git a/tests/test_client.py b/tests/test_client.py index b891829e..eb0370ca 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -87,6 +87,8 @@ async def consume(stream): async with client: await client.send(topic, value=event, key="1") stream = stream_engine.get_stream("my-stream") + assert stream is not None + assert stream.consumer is not None assert stream.consumer.assignment() == [tp0, tp1, tp2] assert stream.consumer.last_stable_offset(tp0) == 0 assert stream.consumer.highwater(tp0) == 1 @@ -164,6 +166,7 @@ async def my_stream(cr: ConsumerRecord, stream: Stream): # give some time so the `commit` can finished await asyncio.sleep(1) + assert my_stream.consumer is not None assert await my_stream.consumer.committed(tp) == 100 # check that the event was consumed @@ -381,6 +384,7 @@ async def stream(stream): await client.send(topic_name, value=value, key=key, partition=10) await asyncio.sleep(1e-10) + assert stream.consumer is not None assert stream.consumer.partitions_for_topic(topic_name) == set([0, 1, 2, 10]) @@ -410,6 +414,8 @@ async def consume(stream): ] stream = stream_engine.get_stream("my-stream") + assert stream is not None + assert stream.consumer is not None assert (await stream.consumer.end_offsets(topic_partitions)) == { TopicPartition(topic="local--kstreams", partition=0): 2, TopicPartition(topic="local--kstreams", partition=2): 1, @@ -445,6 +451,7 @@ async def my_stream(stream: Stream): await asyncio.sleep(1e-10) # check that everything was commited + assert my_stream.consumer is not None assert (await my_stream.consumer.committed(tp)) == total_events - 1 @@ -609,6 +616,9 @@ async def func_stream(consumer: Stream): # to stop the `forever` consumption await asyncio.wait_for(stream.start(), timeout=1.0) + assert stream.consumer is not None + assert stream.rebalance_listener is not None + # simulate partitions assigned on rebalance await stream.rebalance_listener.on_partitions_assigned(assigned=assignments) assert stream.consumer.assignment() == [tp0, tp1, tp2] diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 7ceeccbc..149927bb 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -91,6 +91,7 @@ async def my_stream(stream: Stream): assert rebalance_listener.stream == my_stream # checking that the subscription has also the rebalance_listener + assert my_stream.consumer is not None assert my_stream.consumer._subscription._listener == rebalance_listener await stream_engine.stop() @@ -172,6 +173,7 @@ async def hello_stream(stream: Stream): assert isinstance(rebalance_listener, ManualCommitRebalanceListener) # checking that the subscription has also the rebalance_listener + assert hello_stream.consumer is not None assert isinstance( hello_stream.consumer._subscription._listener, ManualCommitRebalanceListener ) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 270b41a2..0b94361b 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -127,6 +127,8 @@ async def my_coroutine(_): stream_engine.add_stream(stream=stream) await stream.start() + assert stream.consumer is not None + await stream_engine.monitor.generate_consumer_metrics(stream.consumer) consumer = stream.consumer diff --git a/tests/test_stream_engine.py b/tests/test_stream_engine.py index bb1e4c5e..3e5f1f8f 100644 --- a/tests/test_stream_engine.py +++ b/tests/test_stream_engine.py @@ -10,6 +10,13 @@ from kstreams.exceptions import DuplicateStreamException, EngineNotStartedException +class DummyDeserializer: + async def deserialize( + self, consumer_record: ConsumerRecord, **kwargs + ) -> ConsumerRecord: + return consumer_record + + @pytest.mark.asyncio async def test_add_streams(stream_engine: StreamEngine): topic = "local--hello-kpn" @@ -42,9 +49,7 @@ async def my_stream(_): async def test_add_stream_as_instance(stream_engine: StreamEngine): topics = ["local--hello-kpn", "local--hello-kpn-2"] - class MyDeserializer: ... - - deserializer = MyDeserializer() + deserializer = DummyDeserializer() async def processor(stream: Stream): pass @@ -73,9 +78,7 @@ async def processor(stream: Stream): async def test_remove_existing_stream(stream_engine: StreamEngine): topic = "local--hello-kpn" - class MyDeserializer: ... - - deserializer = MyDeserializer() + deserializer = DummyDeserializer() async def processor(stream: Stream): pass @@ -97,9 +100,7 @@ async def processor(stream: Stream): async def test_remove_missing_stream(stream_engine: StreamEngine): topic = "local--hello-kpn" - class MyDeserializer: ... - - deserializer = MyDeserializer() + deserializer = DummyDeserializer() async def processor(stream: Stream): pass @@ -119,9 +120,7 @@ async def processor(stream: Stream): async def test_remove_existing_stream_stops_stream(stream_engine: StreamEngine): topic = "local--hello-kpn" - class MyDeserializer: ... - - deserializer = MyDeserializer() + deserializer = DummyDeserializer() async def processor(stream: Stream): pass @@ -136,7 +135,7 @@ async def processor(stream: Stream): with mock.patch.multiple(Stream, start=mock.DEFAULT, stop=mock.DEFAULT): await stream_engine.remove_stream(my_stream) - Stream.stop.assert_awaited() + Stream.stop.assert_awaited() # type: ignore @pytest.mark.asyncio @@ -149,13 +148,13 @@ async def stream(_): ... with mock.patch.multiple(Consumer, start=mock.DEFAULT, stop=mock.DEFAULT): with mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): await stream_engine.start() - stream_engine._producer.start.assert_awaited() + stream_engine._producer.start.assert_awaited() # type: ignore await asyncio.sleep(0) # Allow stream coroutine to run once Consumer.start.assert_awaited() await stream_engine.stop() - stream_engine._producer.stop.assert_awaited() + stream_engine._producer.stop.assert_awaited() # type: ignore Consumer.stop.assert_awaited() diff --git a/tests/test_streams.py b/tests/test_streams.py index b07cc377..9f05a5ae 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -154,6 +154,45 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream): await stream.stop() +@pytest.mark.asyncio +async def test_stream_all_typing_order_in_setup_type( + stream_engine: StreamEngine, consumer_record_factory +): + topic_name = "local--kstreams" + value = b"test" + + async def getone(_): + return consumer_record_factory(value=value) + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + subscribe=mock.DEFAULT, + getone=getone, + ): + + @stream_engine.stream(topic_name) + async def stream(stream: Stream, cr: ConsumerRecord, send: Send): + assert cr.value == value + assert isinstance(stream, Stream) + assert send == stream_engine.send + await asyncio.sleep(0.2) + + assert stream.consumer is None + assert stream.topics == [topic_name] + + with contextlib.suppress(TimeoutErrorException): + # now it is possible to run a stream directly, so we need + # to stop the `forever` consumption + await asyncio.wait_for(stream.start(), timeout=0.1) + + assert stream.consumer + Consumer.subscribe.assert_called_once_with( + topics=[topic_name], listener=stream.rebalance_listener, pattern=None + ) + await stream.stop() + + @pytest.mark.asyncio async def test_stream_multiple_topics(stream_engine: StreamEngine): topics = ["local--hello-kpn", "local--hello-kpn-2"] @@ -241,6 +280,7 @@ async def stream(_): ... # switch the current Task to the one running in background await asyncio.sleep(0.1) + assert stream.consumer is not None assert stream.consumer._auto_offset_reset == "earliest" assert not stream.consumer._enable_auto_commit @@ -291,6 +331,7 @@ async def streaming_fn(_): await asyncio.sleep(0.1) Consumer.start.assert_awaited() + assert stream_engine._producer is not None stream_engine._producer.start.assert_awaited() await stream_engine.stop() @@ -377,6 +418,7 @@ async def stream(my_stream): await stream.start() # simulate a partitions assigned rebalance + assert stream.rebalance_listener is not None await stream.rebalance_listener.on_partitions_assigned(assigned=assignments) seek_mock.assert_called_once_with( @@ -428,5 +470,6 @@ async def stream(my_stream): await stream.start() # simulate a partitions assigned rebalance + assert stream.rebalance_listener is not None await stream.rebalance_listener.on_partitions_assigned(assigned=assignments) seek_mock.assert_not_called()