Skip to content

Commit

Permalink
test: add class based test sample
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Nov 28, 2024
1 parent 72de291 commit 7b96622
Showing 1 changed file with 55 additions and 19 deletions.
74 changes: 55 additions & 19 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from typing import Callable, Set
from unittest import mock

import pytest

from kstreams import ConsumerRecord, Send, TopicPartition
from kstreams.clients import Consumer, Producer
from kstreams.engine import Stream, StreamEngine
from kstreams.streams import stream
from kstreams.structs import TopicPartitionOffset
from kstreams.test_utils import TestStreamClient
from tests import TimeoutErrorException


# NOTE: remove the test when `no typing` support is deprecated
@pytest.mark.asyncio


async def test_stream_no_typing(stream_engine: StreamEngine, consumer_record_factory):
topic_name = "local--kstreams"
value = b"test"
Expand Down Expand Up @@ -50,7 +49,6 @@ async def stream(stream_instance):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -87,7 +85,6 @@ async def stream(cr: ConsumerRecord):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_generic_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -124,7 +121,43 @@ async def stream(cr: ConsumerRecord[str, bytes]):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_class_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
topic_name = "local--kstreams"
target_topic = "local--kstreams-target"
value = "test"

client = TestStreamClient(stream_engine, topics=[target_topic])

class TestClass:
def __init__(self) -> None:
self.bar = value

async def streaming_fn(self, cr: ConsumerRecord, send: Send):
"""text from func"""

await send(target_topic, value=self.bar)

foo = TestClass()
_stream = Stream(
topics=[topic_name],
func=foo.streaming_fn,
)
stream_engine.add_stream(_stream)

async with client:
await stream_engine.start()
await client.send(topic_name, value=value)
# import ipdb; ipdb.set_trace()
client.get_topic(topic_name=target_topic)
r = await asyncio.wait_for(
client.get_event(topic_name=target_topic), timeout=0.2
)
# r = await client.get_event(topic_name=target_topic)
assert r.value == value


async def test_stream_cr_and_stream_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -154,7 +187,6 @@ async def stream(cr: ConsumerRecord, stream: Stream):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_all_typing(stream_engine: StreamEngine, consumer_record_factory):
topic_name = "local--kstreams"
value = b"test"
Expand Down Expand Up @@ -191,7 +223,6 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_all_typing_order_in_setup_type(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -230,7 +261,6 @@ async def stream(stream: Stream, cr: ConsumerRecord, send: Send):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_multiple_topics(stream_engine: StreamEngine):
topics = ["local--hello-kpn", "local--hello-kpn-2"]

Expand All @@ -251,7 +281,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_subscribe_topics_pattern(stream_engine: StreamEngine):
pattern = "^dev--customer-.*$"

Expand All @@ -273,7 +302,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_subscribe_topics_only_one_pattern(stream_engine: StreamEngine):
"""
We can use only one pattern, so we use the first one
Expand All @@ -299,7 +327,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_custom_conf(stream_engine: StreamEngine):
@stream_engine.stream(
"local--hello-kpn",
Expand All @@ -323,7 +350,6 @@ async def stream(_): ...
assert not stream.consumer._enable_auto_commit


@pytest.mark.asyncio
async def test_stream_getmany(
stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord]
):
Expand Down Expand Up @@ -351,7 +377,6 @@ async def getmany(*args, **kwargs):
save_to_db.assert_called_once_with(topic_partition_crs)


@pytest.mark.asyncio
async def test_stream_decorator(stream_engine: StreamEngine):
topic = "local--hello-kpn"

Expand All @@ -377,7 +402,6 @@ async def streaming_fn(_):
Consumer.stop.assert_awaited()


@pytest.mark.asyncio
async def test_stream_decorates_properly(stream_engine: StreamEngine):
topic = "local--hello-kpn"

Expand All @@ -389,7 +413,6 @@ async def streaming_fn(_):
assert streaming_fn.__doc__ == "text from func"


@pytest.mark.asyncio
async def test_recreate_consumer_on_re_start_stream(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -418,7 +441,6 @@ async def stream(my_stream):
assert consumer is not stream.consumer


@pytest.mark.asyncio
async def test_seek_to_initial_offsets_normal(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -465,7 +487,6 @@ async def stream(my_stream):
)


@pytest.mark.asyncio
async def test_seek_to_initial_offsets_ignores_wrong_input(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -511,3 +532,18 @@ async def stream(my_stream):
assert stream.rebalance_listener is not None
await stream.rebalance_listener.on_partitions_assigned(assigned=assignments)
seek_mock.assert_not_called()


async def test_stream_simple_di_works(
stream_engine: StreamEngine, consumer_record_factory
):
topic = "local--hello-kpn"
cr: ConsumerRecord = consumer_record_factory(topic=topic, value=b"test")

@stream_engine.stream(topic)
async def streaming_fn(cr: ConsumerRecord):
"""text from func"""
return cr.value

r = await streaming_fn.func(cr)
assert r == b"test"

0 comments on commit 7b96622

Please sign in to comment.