diff --git a/kstreams/streams.py b/kstreams/streams.py index e7f27920..6af48cff 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -103,6 +103,7 @@ def __init__( self.name = name or str(uuid.uuid4()) self.deserializer = deserializer self.running = False + self.is_processing = asyncio.Condition() self.initial_offsets = initial_offsets self.seeked_initial_offsets = False self.rebalance_listener = rebalance_listener @@ -121,15 +122,18 @@ def _create_consumer(self) -> Consumer: return self.consumer_class(**config) async def stop(self) -> None: - if not self.running: - return None - - if self.consumer is not None: - await self.consumer.stop() + if self.running: + # Don't run anymore to prevent new events comming self.running = False - if self._consumer_task is not None: - self._consumer_task.cancel() + async with self.is_processing: + # Only enter this block when all the events have been + # proccessed in the middleware chain + if self.consumer is not None: + await self.consumer.stop() + + if self._consumer_task is not None: + self._consumer_task.cancel() async def _subscribe(self) -> None: # Always create a consumer on stream.start @@ -141,7 +145,6 @@ async def _subscribe(self) -> None: self.consumer.subscribe( topics=self.topics, listener=self.rebalance_listener ) - self.running = True async def commit( self, offsets: typing.Optional[typing.Dict[TopicPartition, int]] = None @@ -206,6 +209,7 @@ async def start(self) -> None: return None await self._subscribe() + self.running = True if self.udf_handler.type == UDFType.NO_TYPING: # normal use case @@ -236,9 +240,10 @@ async def func_wrapper(self, func: typing.Awaitable) -> None: logger.exception(f"CRASHED Stream!!! Task {self._consumer_task} \n\n {e}") async def func_wrapper_with_typing(self) -> None: - while True: + while self.running: cr = await self.getone() - await self.func(cr) + async with self.is_processing: + await self.func(cr) def seek_to_initial_offsets(self) -> None: if not self.seeked_initial_offsets and self.consumer is not None: diff --git a/kstreams/test_utils/test_utils.py b/kstreams/test_utils/test_utils.py index 9059514c..4beb75cb 100644 --- a/kstreams/test_utils/test_utils.py +++ b/kstreams/test_utils/test_utils.py @@ -17,8 +17,7 @@ class TestMonitor(PrometheusMonitor): __test__ = False def start(self, *args, **kwargs) -> None: - print("herte....") - # ... + ... async def stop(self, *args, **kwargs) -> None: ... diff --git a/tests/test_stream_engine.py b/tests/test_stream_engine.py index 52ece891..8f68268c 100644 --- a/tests/test_stream_engine.py +++ b/tests/test_stream_engine.py @@ -264,6 +264,39 @@ async def stream(_): Consumer.stop.assert_awaited() +@pytest.mark.asyncio +async def test_wait_for_streams_before_stop( + stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord] +): + topic = "local--hello-kpn" + value = b"Hello world" + save_to_db = mock.AsyncMock() + + async def getone(_): + return consumer_record_factory(value=value) + + @stream_engine.stream(topic) + async def stream(cr: ConsumerRecord): + # Use 5 seconds sleep to simulate a super slow event processing + await asyncio.sleep(5) + await save_to_db(cr.value) + + with mock.patch.multiple( + Consumer, + start=mock.DEFAULT, + stop=mock.DEFAULT, + getone=getone, + ), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT): + await stream_engine.start() + await asyncio.sleep(0) # Allow stream coroutine to run once + + # stop engine immediately, this should not break the streams + # and it should wait until the event is processed. + await stream_engine.stop() + Consumer.stop.assert_awaited() + save_to_db.assert_awaited_once_with(value) + + @pytest.mark.asyncio async def test_recreate_consumer_on_re_start_stream(stream_engine: StreamEngine): with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"):