diff --git a/README.md b/README.md index 94d77187..b0749031 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ if __name__ == "__main__": - [x] Yield events from streams - [x] [Opentelemetry Instrumentation](https://github.com/kpn/opentelemetry-instrumentation-kstreams) - [x] Middlewares +- [x] Hooks (on_startup, on_stop, after_startup, after_stop) - [ ] Store (kafka streams pattern) - [ ] Stream Join - [ ] Windowing diff --git a/docs/engine.md b/docs/engine.md index cf5f2b58..839ac39c 100644 --- a/docs/engine.md +++ b/docs/engine.md @@ -5,3 +5,4 @@ show_root_heading: true docstring_section_style: table show_signature_annotations: false + members_order: source diff --git a/kstreams/create.py b/kstreams/create.py index 6ee24aa4..51b01048 100644 --- a/kstreams/create.py +++ b/kstreams/create.py @@ -5,6 +5,7 @@ from .engine import StreamEngine from .prometheus.monitor import PrometheusMonitor from .serializers import Deserializer, Serializer +from .types import EngineHooks def create_engine( @@ -15,6 +16,10 @@ def create_engine( serializer: Optional[Serializer] = None, deserializer: Optional[Deserializer] = None, monitor: Optional[PrometheusMonitor] = None, + on_startup: Optional[EngineHooks] = None, + on_stop: Optional[EngineHooks] = None, + after_startup: Optional[EngineHooks] = None, + after_stop: Optional[EngineHooks] = None, ) -> StreamEngine: if monitor is None: monitor = PrometheusMonitor() @@ -30,4 +35,8 @@ def create_engine( serializer=serializer, deserializer=deserializer, monitor=monitor, + on_startup=on_startup, + on_stop=on_stop, + after_startup=after_startup, + after_stop=after_stop, ) diff --git a/kstreams/engine.py b/kstreams/engine.py index 0d1d2f74..d59a5371 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -17,8 +17,8 @@ from .streams import Stream, StreamFunc from .streams import stream as stream_func from .streams_utils import UDFType -from .types import Headers, NextMiddlewareCall -from .utils import encode_headers +from .types 
import EngineHooks, Headers, NextMiddlewareCall +from .utils import encode_headers, execute_hooks logger = logging.getLogger(__name__) @@ -68,6 +68,10 @@ def __init__( title: typing.Optional[str] = None, deserializer: typing.Optional[Deserializer] = None, serializer: typing.Optional[Serializer] = None, + on_startup: typing.Optional[EngineHooks] = None, + on_stop: typing.Optional[EngineHooks] = None, + after_startup: typing.Optional[EngineHooks] = None, + after_stop: typing.Optional[EngineHooks] = None, ) -> None: self.title = title self.backend = backend @@ -78,6 +82,10 @@ def __init__( self.monitor = monitor self._producer: typing.Optional[typing.Type[Producer]] = None self._streams: typing.List[Stream] = [] + self._on_startup = [] if on_startup is None else list(on_startup) + self._on_stop = [] if on_stop is None else list(on_stop) + self._after_startup = [] if after_startup is None else list(after_startup) + self._after_stop = [] if after_stop is None else list(after_stop) async def send( self, @@ -133,6 +141,9 @@ async def send( return metadata async def start(self) -> None: + # Execute on_startup hooks + await execute_hooks(self._on_startup) + # add the producer and streams to the Monitor self.monitor.add_producer(self._producer) self.monitor.add_streams(self._streams) @@ -140,11 +151,144 @@ async def start(self) -> None: await self.start_producer() await self.start_streams() + # Execute after_startup hooks + await execute_hooks(self._after_startup) + + def on_startup( + self, + func: typing.Callable[[], typing.Any], + ) -> typing.Callable[[], typing.Any]: + """ + A list of callables to run before the engine starts. + Handlers are callables that do not take any arguments, and may be either + standard functions, or async functions. + + Attributes: + func typing.Callable[[], typing.Any]: Function to call before the engine starts + + !!! 
Example + ```python title="Engine before startup" + + import kstreams + + stream_engine = kstreams.create_engine( + title="my-stream-engine" + ) + + @stream_engine.on_startup + async def init_db() -> None: + print("Initializing Database Connections") + await db_connect() + + + @stream_engine.on_startup + async def start_background_task() -> None: + print("Some background task") + ``` + """ + self._on_startup.append(func) + return func + + def on_stop( + self, + func: typing.Callable[[], typing.Any], + ) -> typing.Callable[[], typing.Any]: + """ + A list of callables to run before the engine stops. + Handlers are callables that do not take any arguments, and may be either + standard functions, or async functions. + + Attributes: + func typing.Callable[[], typing.Any]: Function to call before the engine stops + + !!! Example + ```python title="Engine before stops" + + import kstreams + + stream_engine = kstreams.create_engine( + title="my-stream-engine" + ) + + @stream_engine.on_stop + async def close_db() -> None: + print("Closing Database Connections") + await db_close() + ``` + """ + self._on_stop.append(func) + return func + + def after_startup( + self, + func: typing.Callable[[], typing.Any], + ) -> typing.Callable[[], typing.Any]: + """ + A list of callables to run after the engine starts. + Handlers are callables that do not take any arguments, and may be either + standard functions, or async functions. + + Attributes: + func typing.Callable[[], typing.Any]: Function to call after the engine starts + + !!! 
Example + ```python title="Engine after startup" + + import kstreams + + stream_engine = kstreams.create_engine( + title="my-stream-engine" + ) + + @stream_engine.after_startup + async def after_startup() -> None: + print("Set pod as healthy") + await mark_healthy_pod() + ``` + """ + self._after_startup.append(func) + return func + + def after_stop( + self, + func: typing.Callable[[], typing.Any], + ) -> typing.Callable[[], typing.Any]: + """ + A list of callables to run after the engine stops. + Handlers are callables that do not take any arguments, and may be either + standard functions, or async functions. + + Attributes: + func typing.Callable[[], typing.Any]: Function to call after the engine stops + + !!! Example + ```python title="Engine after stops" + + import kstreams + + stream_engine = kstreams.create_engine( + title="my-stream-engine" + ) + + @stream_engine.after_stop + async def after_stop() -> None: + print("Finishing background tasks") + ``` + """ + self._after_stop.append(func) + return func + async def stop(self) -> None: + # Execute on_stop hooks + await execute_hooks(self._on_stop) + await self.monitor.stop() await self.stop_producer() await self.stop_streams() + # Execute after_stop hooks + await execute_hooks(self._after_stop) + async def stop_producer(self): if self._producer is not None: await self._producer.stop() diff --git a/kstreams/types.py b/kstreams/types.py index 1e6096be..469af45d 100644 --- a/kstreams/types.py +++ b/kstreams/types.py @@ -9,6 +9,7 @@ EncodedHeaders = typing.Sequence[typing.Tuple[str, bytes]] StreamFunc = typing.Callable NextMiddlewareCall = typing.Callable[[ConsumerRecord], typing.Awaitable[None]] +EngineHooks = typing.Sequence[typing.Callable[[], typing.Any]] class Send(typing.Protocol): diff --git a/kstreams/utils.py b/kstreams/utils.py index f1794fe6..1c431bef 100644 --- a/kstreams/utils.py +++ b/kstreams/utils.py @@ -1,4 +1,5 @@ import contextlib +import inspect import ssl from tempfile import NamedTemporaryFile 
from typing import Any, Optional, Union @@ -92,4 +93,12 @@ def create_ssl_context( ) +async def execute_hooks(hooks: types.EngineHooks) -> None: + for hook in hooks: + if inspect.iscoroutinefunction(hook): + await hook() + else: + hook() + + __all__ = ["create_ssl_context", "create_ssl_context_from_mem", "encode_headers"] diff --git a/tests/test_engine_hooks.py b/tests/test_engine_hooks.py new file mode 100644 index 00000000..307f493b --- /dev/null +++ b/tests/test_engine_hooks.py @@ -0,0 +1,189 @@ +import asyncio +from unittest import mock + +import pytest + +from kstreams import ConsumerRecord, StreamEngine +from kstreams.clients import Consumer, Producer + + +@pytest.mark.asyncio +async def test_hook_on_startup(stream_engine: StreamEngine, consumer_record_factory): + on_startup_sync_mock = mock.Mock() + on_startup_async_mock = mock.AsyncMock() + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + stop=mock.DEFAULT, + ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): + assert stream_engine._on_startup == [] + + @stream_engine.stream("local--kstreams") + async def stream(cr: ConsumerRecord): + ... 
+ + @stream_engine.on_startup + async def init_db(): + on_startup_sync_mock() + + @stream_engine.on_startup + async def backgound_task(): + await on_startup_async_mock() + + # check monitoring is not running + assert not stream_engine.monitor.running + + # check stream is not running + assert not stream.running + + assert stream_engine._on_startup == [init_db, backgound_task] + + # check that `on_startup` hooks were not called before + # `stream_engine.start()` was called + on_startup_sync_mock.assert_not_called() + on_startup_async_mock.assert_not_awaited() + + await stream_engine.start() + + # check that `on_startup` hooks were called + on_startup_sync_mock.assert_called_once() + on_startup_async_mock.assert_awaited_once() + + await stream_engine.stop() + + +@pytest.mark.asyncio +async def test_hook_after_startup(stream_engine: StreamEngine, consumer_record_factory): + after_startup_async_mock = mock.AsyncMock() + set_healthy_pod = mock.AsyncMock() + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + stop=mock.DEFAULT, + ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): + assert stream_engine._after_startup == [] + + @stream_engine.stream("local--kstreams") + async def stream(cr: ConsumerRecord): + ... 
+ + @stream_engine.after_startup + async def healthy(): + await set_healthy_pod() + + @stream_engine.after_startup + async def backgound_task(): + # give some time to start the tasks + await asyncio.sleep(0.1) + + await after_startup_async_mock() + + # check monitoring is running + assert stream_engine.monitor.running + + # check stream is running + assert stream.running + + assert stream_engine._after_startup == [healthy, backgound_task] + + # check that `after_startup` hooks were not called before + # `stream_engine.start()` was called + set_healthy_pod.assert_not_awaited() + after_startup_async_mock.assert_not_awaited() + + await stream_engine.start() + + # check that `after_startup` hooks were called + set_healthy_pod.assert_awaited_once() + after_startup_async_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_hook_on_stop(stream_engine: StreamEngine, consumer_record_factory): + close_db_mock = mock.Mock() + backgound_task_mock = mock.AsyncMock() + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + stop=mock.DEFAULT, + ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): + assert stream_engine._on_stop == [] + + @stream_engine.stream("local--kstreams") + async def stream(cr: ConsumerRecord): + ... 
+ + @stream_engine.on_stop + async def close_db(): + close_db_mock() + + @stream_engine.on_stop + async def stop_backgound_task(): + backgound_task_mock.cancel() + + # check that monitoring is running + assert stream_engine.monitor.running + + # check streams are running + assert stream.running + + assert stream_engine._on_stop == [close_db, stop_backgound_task] + await stream_engine.start() + + # give some time to start the tasks + await asyncio.sleep(0.1) + + # check that `on_stop` hooks were not called before + # `stream_engine.stop()` was called + close_db_mock.assert_not_called() + backgound_task_mock.cancel.assert_not_awaited() + + await stream_engine.stop() + + # check that `on_stop` hooks were called + close_db_mock.assert_called_once() + backgound_task_mock.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_hook_after_stop(stream_engine: StreamEngine, consumer_record_factory): + delete_files_mock = mock.AsyncMock() + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + stop=mock.DEFAULT, + ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): + assert stream_engine._after_stop == [] + + @stream_engine.stream("local--kstreams") + async def stream(cr: ConsumerRecord): + ... + + @stream_engine.after_stop + async def delete_files(): + await delete_files_mock() + + # check that monitoring is not already running + assert not stream_engine.monitor.running + + # check streams are not running + assert not stream.running + + assert stream_engine._after_stop == [delete_files] + await stream_engine.start() + + # check that `after_stop` hooks were not called before + # `stream_engine.stop()` was called + delete_files_mock.assert_not_awaited() + + await stream_engine.stop() + + # give some time to start the tasks + await asyncio.sleep(0.1) + + # check that `after_stop` hooks were called + delete_files_mock.assert_awaited_once()