From db92d56fccd875eb144e8e50ec12ce8feea1b825 Mon Sep 17 00:00:00 2001 From: Santiago Fraire Willemoes Date: Fri, 1 Nov 2024 08:22:58 +0100 Subject: [PATCH] fix(streams_utils): properly identify if typed or not --- kstreams/streams_utils.py | 2 +- tests/test_streams.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/kstreams/streams_utils.py b/kstreams/streams_utils.py index d1c55133..ad36f226 100644 --- a/kstreams/streams_utils.py +++ b/kstreams/streams_utils.py @@ -60,7 +60,7 @@ async def consume(cr: ConsumerRecord, stream: Stream, send: Send): first_annotation = params[0].annotation - if first_annotation in (inspect._empty, Stream): + if first_annotation in (inspect._empty, Stream) and len(params) < 2: # use case 1 NO_TYPING return UDFType.NO_TYPING # typing cases diff --git a/tests/test_streams.py b/tests/test_streams.py index 93db6a8b..9f05a5ae 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -154,6 +154,45 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream): await stream.stop() +@pytest.mark.asyncio +async def test_stream_all_typing_order_in_setup_type( + stream_engine: StreamEngine, consumer_record_factory +): + topic_name = "local--kstreams" + value = b"test" + + async def getone(_): + return consumer_record_factory(value=value) + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + subscribe=mock.DEFAULT, + getone=getone, + ): + + @stream_engine.stream(topic_name) + async def stream(stream: Stream, cr: ConsumerRecord, send: Send): + assert cr.value == value + assert isinstance(stream, Stream) + assert send == stream_engine.send + await asyncio.sleep(0.2) + + assert stream.consumer is None + assert stream.topics == [topic_name] + + with contextlib.suppress(TimeoutErrorException): + # now it is possible to run a stream directly, so we need + # to stop the `forever` consumption + await asyncio.wait_for(stream.start(), timeout=0.1) + + assert stream.consumer + Consumer.subscribe.assert_called_once_with( + topics=[topic_name], listener=stream.rebalance_listener, pattern=None + ) + await stream.stop() + + @pytest.mark.asyncio async def test_stream_multiple_topics(stream_engine: StreamEngine): topics = ["local--hello-kpn", "local--hello-kpn-2"]