Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/setup type #232

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kstreams/streams_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def consume(cr: ConsumerRecord, stream: Stream, send: Send):

first_annotation = params[0].annotation

if first_annotation in (inspect._empty, Stream):
if first_annotation in (inspect._empty, Stream) and len(params) < 2:
# use case 1 NO_TYPING
return UDFType.NO_TYPING
# typing cases
Expand Down
10 changes: 10 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def consume(stream):
async with client:
await client.send(topic, value=event, key="1")
stream = stream_engine.get_stream("my-stream")
assert stream is not None
assert stream.consumer is not None
assert stream.consumer.assignment() == [tp0, tp1, tp2]
assert stream.consumer.last_stable_offset(tp0) == 0
assert stream.consumer.highwater(tp0) == 1
Expand Down Expand Up @@ -164,6 +166,7 @@ async def my_stream(cr: ConsumerRecord, stream: Stream):

# give some time so the `commit` can finished
await asyncio.sleep(1)
assert my_stream.consumer is not None
assert await my_stream.consumer.committed(tp) == 100

# check that the event was consumed
Expand Down Expand Up @@ -381,6 +384,7 @@ async def stream(stream):
await client.send(topic_name, value=value, key=key, partition=10)

await asyncio.sleep(1e-10)
assert stream.consumer is not None
assert stream.consumer.partitions_for_topic(topic_name) == set([0, 1, 2, 10])


Expand Down Expand Up @@ -410,6 +414,8 @@ async def consume(stream):
]

stream = stream_engine.get_stream("my-stream")
assert stream is not None
assert stream.consumer is not None
assert (await stream.consumer.end_offsets(topic_partitions)) == {
TopicPartition(topic="local--kstreams", partition=0): 2,
TopicPartition(topic="local--kstreams", partition=2): 1,
Expand Down Expand Up @@ -445,6 +451,7 @@ async def my_stream(stream: Stream):

await asyncio.sleep(1e-10)
# check that everything was commited
assert my_stream.consumer is not None
assert (await my_stream.consumer.committed(tp)) == total_events - 1


Expand Down Expand Up @@ -609,6 +616,9 @@ async def func_stream(consumer: Stream):
# to stop the `forever` consumption
await asyncio.wait_for(stream.start(), timeout=1.0)

assert stream.consumer is not None
assert stream.rebalance_listener is not None

# simulate partitions assigned on rebalance
await stream.rebalance_listener.on_partitions_assigned(assigned=assignments)
assert stream.consumer.assignment() == [tp0, tp1, tp2]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ async def my_stream(stream: Stream):
assert rebalance_listener.stream == my_stream

# checking that the subscription has also the rebalance_listener
assert my_stream.consumer is not None
assert my_stream.consumer._subscription._listener == rebalance_listener

await stream_engine.stop()
Expand Down Expand Up @@ -172,6 +173,7 @@ async def hello_stream(stream: Stream):

assert isinstance(rebalance_listener, ManualCommitRebalanceListener)
# checking that the subscription has also the rebalance_listener
assert hello_stream.consumer is not None
assert isinstance(
hello_stream.consumer._subscription._listener, ManualCommitRebalanceListener
)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ async def my_coroutine(_):
stream_engine.add_stream(stream=stream)
await stream.start()

assert stream.consumer is not None

await stream_engine.monitor.generate_consumer_metrics(stream.consumer)
consumer = stream.consumer

Expand Down
29 changes: 14 additions & 15 deletions tests/test_stream_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from kstreams.exceptions import DuplicateStreamException, EngineNotStartedException


class DummyDeserializer:
async def deserialize(
self, consumer_record: ConsumerRecord, **kwargs
) -> ConsumerRecord:
return consumer_record


@pytest.mark.asyncio
async def test_add_streams(stream_engine: StreamEngine):
topic = "local--hello-kpn"
Expand Down Expand Up @@ -42,9 +49,7 @@ async def my_stream(_):
async def test_add_stream_as_instance(stream_engine: StreamEngine):
topics = ["local--hello-kpn", "local--hello-kpn-2"]

class MyDeserializer: ...

deserializer = MyDeserializer()
deserializer = DummyDeserializer()

async def processor(stream: Stream):
pass
Expand Down Expand Up @@ -73,9 +78,7 @@ async def processor(stream: Stream):
async def test_remove_existing_stream(stream_engine: StreamEngine):
topic = "local--hello-kpn"

class MyDeserializer: ...

deserializer = MyDeserializer()
deserializer = DummyDeserializer()

async def processor(stream: Stream):
pass
Expand All @@ -97,9 +100,7 @@ async def processor(stream: Stream):
async def test_remove_missing_stream(stream_engine: StreamEngine):
topic = "local--hello-kpn"

class MyDeserializer: ...

deserializer = MyDeserializer()
deserializer = DummyDeserializer()

async def processor(stream: Stream):
pass
Expand All @@ -119,9 +120,7 @@ async def processor(stream: Stream):
async def test_remove_existing_stream_stops_stream(stream_engine: StreamEngine):
topic = "local--hello-kpn"

class MyDeserializer: ...

deserializer = MyDeserializer()
deserializer = DummyDeserializer()

async def processor(stream: Stream):
pass
Expand All @@ -136,7 +135,7 @@ async def processor(stream: Stream):

with mock.patch.multiple(Stream, start=mock.DEFAULT, stop=mock.DEFAULT):
await stream_engine.remove_stream(my_stream)
Stream.stop.assert_awaited()
Stream.stop.assert_awaited() # type: ignore


@pytest.mark.asyncio
Expand All @@ -149,13 +148,13 @@ async def stream(_): ...
with mock.patch.multiple(Consumer, start=mock.DEFAULT, stop=mock.DEFAULT):
with mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT):
await stream_engine.start()
stream_engine._producer.start.assert_awaited()
stream_engine._producer.start.assert_awaited() # type: ignore

await asyncio.sleep(0) # Allow stream coroutine to run once
Consumer.start.assert_awaited()

await stream_engine.stop()
stream_engine._producer.stop.assert_awaited()
stream_engine._producer.stop.assert_awaited() # type: ignore
Consumer.stop.assert_awaited()


Expand Down
43 changes: 43 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,45 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_all_typing_order_in_setup_type(
stream_engine: StreamEngine, consumer_record_factory
):
topic_name = "local--kstreams"
value = b"test"

async def getone(_):
return consumer_record_factory(value=value)

with mock.patch.multiple(
Consumer,
start=mock.DEFAULT,
subscribe=mock.DEFAULT,
getone=getone,
):

@stream_engine.stream(topic_name)
async def stream(stream: Stream, cr: ConsumerRecord, send: Send):
assert cr.value == value
assert isinstance(stream, Stream)
assert send == stream_engine.send
await asyncio.sleep(0.2)

assert stream.consumer is None
assert stream.topics == [topic_name]

with contextlib.suppress(TimeoutErrorException):
# now it is possible to run a stream directly, so we need
# to stop the `forever` consumption
await asyncio.wait_for(stream.start(), timeout=0.1)

assert stream.consumer
Consumer.subscribe.assert_called_once_with(
topics=[topic_name], listener=stream.rebalance_listener, pattern=None
)
await stream.stop()


@pytest.mark.asyncio
async def test_stream_multiple_topics(stream_engine: StreamEngine):
topics = ["local--hello-kpn", "local--hello-kpn-2"]
Expand Down Expand Up @@ -241,6 +280,7 @@ async def stream(_): ...
# switch the current Task to the one running in background
await asyncio.sleep(0.1)

assert stream.consumer is not None
assert stream.consumer._auto_offset_reset == "earliest"
assert not stream.consumer._enable_auto_commit

Expand Down Expand Up @@ -291,6 +331,7 @@ async def streaming_fn(_):
await asyncio.sleep(0.1)

Consumer.start.assert_awaited()
assert stream_engine._producer is not None
stream_engine._producer.start.assert_awaited()

await stream_engine.stop()
Expand Down Expand Up @@ -377,6 +418,7 @@ async def stream(my_stream):

await stream.start()
# simulate a partitions assigned rebalance
assert stream.rebalance_listener is not None
await stream.rebalance_listener.on_partitions_assigned(assigned=assignments)

seek_mock.assert_called_once_with(
Expand Down Expand Up @@ -428,5 +470,6 @@ async def stream(my_stream):

await stream.start()
# simulate a partitions assigned rebalance
assert stream.rebalance_listener is not None
await stream.rebalance_listener.on_partitions_assigned(assigned=assignments)
seek_mock.assert_not_called()
Loading