diff --git a/kstreams/engine.py b/kstreams/engine.py index 0144f792..3d4affee 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -10,7 +10,7 @@ from .backends.kafka import Kafka from .clients import Consumer, Producer from .exceptions import DuplicateStreamException, EngineNotStartedException -from .middleware import ExceptionMiddleware, Middleware +from .middleware import Middleware from .middleware.udf_middleware import UdfHandler from .prometheus.monitor import PrometheusMonitor from .rebalance_listener import MetricsRebalanceListener, RebalanceListener @@ -343,10 +343,39 @@ def get_stream(self, name: str) -> typing.Optional[Stream]: return stream def add_stream( - self, stream: Stream, error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP + self, stream: Stream, error_policy: typing.Optional[StreamErrorPolicy] = None ) -> None: + """ + Add a stream to the engine. + + This method registers a new stream with the engine, setting up necessary + configurations and handlers. If a stream with the same name already exists, + a DuplicateStreamException is raised. + + Args: + stream: The stream to be added. + error_policy: An optional error policy to be applied to the stream. + You should probably set it when instantiating a Stream, not here. + + Raises: + DuplicateStreamException: If a stream with the same name already exists. + + Notes: + - If the stream does not have a deserializer, the engine's deserializer + is assigned to it. + - If the stream does not have a rebalance listener, a default + MetricsRebalanceListener is assigned. + - The stream's UDF handler is set up with the provided function and + engine's send method. + - If the stream's UDF handler type is not NO_TYPING, a middleware stack + is built for the stream's function. 
+ """ if self.exist_stream(stream.name): raise DuplicateStreamException(name=stream.name) + + if error_policy is not None: + stream.error_policy = error_policy + stream.backend = self.backend if stream.deserializer is None: stream.deserializer = self.deserializer @@ -357,8 +386,8 @@ def add_stream( # when the callbacks are called stream.rebalance_listener = MetricsRebalanceListener() - stream.rebalance_listener.stream = stream # type: ignore - stream.rebalance_listener.engine = self # type: ignore + stream.rebalance_listener.stream = stream + stream.rebalance_listener.engine = self stream.udf_handler = UdfHandler( next_call=stream.func, @@ -369,21 +398,14 @@ def add_stream( # NOTE: When `no typing` support is deprecated this check can # be removed if stream.udf_handler.type != UDFType.NO_TYPING: - stream.func = self.build_stream_middleware_stack( - stream=stream, error_policy=error_policy - ) + stream.func = self._build_stream_middleware_stack(stream=stream) - def build_stream_middleware_stack( - self, *, stream: Stream, error_policy: StreamErrorPolicy - ) -> NextMiddlewareCall: + def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall: assert stream.udf_handler, "UdfHandler can not be None" - stream.middlewares = [ - Middleware(ExceptionMiddleware, engine=self, error_policy=error_policy), - ] + stream.middlewares - + middlewares = stream.get_middlewares(self) next_call = stream.udf_handler - for middleware, options in reversed(stream.middlewares): + for middleware, options in reversed(middlewares): next_call = middleware( next_call=next_call, send=self.send, stream=stream, **options ) diff --git a/kstreams/streams.py b/kstreams/streams.py index f6ff9f61..11c06ae0 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -1,4 +1,5 @@ import asyncio +import collections import inspect import logging import typing @@ -9,6 +10,7 @@ from kstreams import ConsumerRecord, TopicPartition from kstreams.exceptions import BackendNotSet +from 
kstreams.middleware.middleware import ExceptionMiddleware from kstreams.structs import TopicPartitionOffset from .backends.kafka import Kafka @@ -16,9 +18,12 @@ from .middleware import Middleware, udf_middleware from .rebalance_listener import RebalanceListener from .serializers import Deserializer -from .streams_utils import UDFType +from .streams_utils import StreamErrorPolicy, UDFType from .types import StreamFunc +if typing.TYPE_CHECKING: + from kstreams import StreamEngine + logger = logging.getLogger(__name__) @@ -152,6 +157,7 @@ def __init__( initial_offsets: typing.Optional[typing.List[TopicPartitionOffset]] = None, rebalance_listener: typing.Optional[RebalanceListener] = None, middlewares: typing.Optional[typing.List[Middleware]] = None, + error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP, ) -> None: self.func = func self.backend = backend @@ -169,6 +175,7 @@ def __init__( self.udf_handler: typing.Optional[udf_middleware.UdfHandler] = None self.topics = [topics] if isinstance(topics, str) else topics self.subscribe_by_pattern = subscribe_by_pattern + self.error_policy = error_policy def _create_consumer(self) -> Consumer: if self.backend is None: @@ -176,6 +183,28 @@ def _create_consumer(self) -> Consumer: config = {**self.backend.model_dump(), **self.config} return self.consumer_class(**config) + def get_middlewares( + self, engine: "StreamEngine" + ) -> collections.abc.Sequence[Middleware]: + """ + Retrieve the full list of middlewares for this stream. + + Use this instead of the `middlewares` attribute to get the list of middlewares. + + Args: + engine: The stream engine instance. + + Returns: + A sequence of Middleware instances, + including the ExceptionMiddleware with the stream's error policy and any + additional middlewares. 
+ """ + return [ + Middleware( + ExceptionMiddleware, engine=engine, error_policy=self.error_policy + ) + ] + self.middlewares + async def stop(self) -> None: if self.running: # Don't run anymore to prevent new events comming diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py index c7198c31..c6ede1b6 100644 --- a/tests/middleware/test_middleware.py +++ b/tests/middleware/test_middleware.py @@ -30,13 +30,19 @@ async def process(cr: ConsumerRecord, stream: Stream): ... my_stream = stream_engine.get_stream(stream_name) - my_stream_local = stream_engine.get_stream(stream_name) + if my_stream is None: + raise ValueError("Stream not found") + my_stream_local = stream_engine.get_stream(stream_name_local) + if my_stream_local is None: + raise ValueError("Stream not found") + middlewares = [ - middleware_factory.middleware for middleware_factory in my_stream.middlewares + middleware_factory.middleware + for middleware_factory in my_stream.get_middlewares(stream_engine) ] middlewares_stream_local = [ middleware_factory.middleware - for middleware_factory in my_stream_local.middlewares + for middleware_factory in my_stream_local.get_middlewares(stream_engine) ] assert ( middlewares @@ -63,8 +69,11 @@ async def consume(cr: ConsumerRecord): ... 
my_stream = stream_engine.get_stream(stream_name) + if my_stream is None: + raise ValueError("Stream not found") middlewares = [ - middleware_factory.middleware for middleware_factory in my_stream.middlewares + middleware_factory.middleware + for middleware_factory in my_stream.get_middlewares(stream_engine) ] assert middlewares == [ middleware.ExceptionMiddleware, diff --git a/tests/test_streams_error_policy.py b/tests/test_streams_error_policy.py index 83632d71..276598ac 100644 --- a/tests/test_streams_error_policy.py +++ b/tests/test_streams_error_policy.py @@ -93,6 +93,7 @@ async def test_stop_application_error_policy(stream_engine: StreamEngine): client = TestStreamClient(stream_engine) with mock.patch("signal.raise_signal"): + @stream_engine.stream(topic, error_policy=StreamErrorPolicy.STOP_APPLICATION) async def my_stream(cr: ConsumerRecord): raise ValueError("Crashing Stream...")