Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add class based test sample #250

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 55 additions & 19 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from typing import Callable, Set
from unittest import mock

import pytest

from kstreams import ConsumerRecord, Send, TopicPartition
from kstreams.clients import Consumer, Producer
from kstreams.engine import Stream, StreamEngine
from kstreams.streams import stream
from kstreams.structs import TopicPartitionOffset
from kstreams.test_utils import TestStreamClient
from tests import TimeoutErrorException


# NOTE: remove the test when `no typing` support is deprecated
@pytest.mark.asyncio


async def test_stream_no_typing(stream_engine: StreamEngine, consumer_record_factory):
topic_name = "local--kstreams"
value = b"test"
Expand Down Expand Up @@ -50,7 +49,6 @@ async def stream(stream_instance):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -87,7 +85,6 @@ async def stream(cr: ConsumerRecord):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_generic_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -124,7 +121,43 @@ async def stream(cr: ConsumerRecord[str, bytes]):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_class_cr_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
topic_name = "local--kstreams"
target_topic = "local--kstreams-target"
value = "test"

client = TestStreamClient(stream_engine, topics=[target_topic])

class TestClass:
def __init__(self) -> None:
self.bar = value

async def streaming_fn(self, cr: ConsumerRecord, send: Send):
"""text from func"""

await send(target_topic, value=self.bar)

foo = TestClass()
_stream = Stream(
topics=[topic_name],
func=foo.streaming_fn,
)
stream_engine.add_stream(_stream)

async with client:
await stream_engine.start()
await client.send(topic_name, value=value)
# import ipdb; ipdb.set_trace()
client.get_topic(topic_name=target_topic)
r = await asyncio.wait_for(
client.get_event(topic_name=target_topic), timeout=0.2
)
# r = await client.get_event(topic_name=target_topic)
assert r.value == value


async def test_stream_cr_and_stream_with_typing(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -154,7 +187,6 @@ async def stream(cr: ConsumerRecord, stream: Stream):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_all_typing(stream_engine: StreamEngine, consumer_record_factory):
topic_name = "local--kstreams"
value = b"test"
Expand Down Expand Up @@ -191,7 +223,6 @@ async def stream(cr: ConsumerRecord, send: Send, stream: Stream):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_all_typing_order_in_setup_type(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -230,7 +261,6 @@ async def stream(stream: Stream, cr: ConsumerRecord, send: Send):
await stream.stop()


@pytest.mark.asyncio
async def test_stream_multiple_topics(stream_engine: StreamEngine):
topics = ["local--hello-kpn", "local--hello-kpn-2"]

Expand All @@ -251,7 +281,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_subscribe_topics_pattern(stream_engine: StreamEngine):
pattern = "^dev--customer-.*$"

Expand All @@ -273,7 +302,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_subscribe_topics_only_one_pattern(stream_engine: StreamEngine):
"""
We can use only one pattern, so we use the first one
Expand All @@ -299,7 +327,6 @@ async def stream(_): ...
)


@pytest.mark.asyncio
async def test_stream_custom_conf(stream_engine: StreamEngine):
@stream_engine.stream(
"local--hello-kpn",
Expand All @@ -323,7 +350,6 @@ async def stream(_): ...
assert not stream.consumer._enable_auto_commit


@pytest.mark.asyncio
async def test_stream_getmany(
stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord]
):
Expand Down Expand Up @@ -351,7 +377,6 @@ async def getmany(*args, **kwargs):
save_to_db.assert_called_once_with(topic_partition_crs)


@pytest.mark.asyncio
async def test_stream_decorator(stream_engine: StreamEngine):
topic = "local--hello-kpn"

Expand All @@ -377,7 +402,6 @@ async def streaming_fn(_):
Consumer.stop.assert_awaited()


@pytest.mark.asyncio
async def test_stream_decorates_properly(stream_engine: StreamEngine):
topic = "local--hello-kpn"

Expand All @@ -389,7 +413,6 @@ async def streaming_fn(_):
assert streaming_fn.__doc__ == "text from func"


@pytest.mark.asyncio
async def test_recreate_consumer_on_re_start_stream(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -418,7 +441,6 @@ async def stream(my_stream):
assert consumer is not stream.consumer


@pytest.mark.asyncio
async def test_seek_to_initial_offsets_normal(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -465,7 +487,6 @@ async def stream(my_stream):
)


@pytest.mark.asyncio
async def test_seek_to_initial_offsets_ignores_wrong_input(
stream_engine: StreamEngine, consumer_record_factory
):
Expand Down Expand Up @@ -511,3 +532,18 @@ async def stream(my_stream):
assert stream.rebalance_listener is not None
await stream.rebalance_listener.on_partitions_assigned(assigned=assignments)
seek_mock.assert_not_called()


async def test_stream_simple_di_works(
stream_engine: StreamEngine, consumer_record_factory
):
topic = "local--hello-kpn"
cr: ConsumerRecord = consumer_record_factory(topic=topic, value=b"test")

@stream_engine.stream(topic)
async def streaming_fn(cr: ConsumerRecord):
"""text from func"""
return cr.value

r = await streaming_fn.func(cr)
assert r == b"test"
Loading