Skip to content

Commit e013ae5

Browse files
committed
fix(StreamEngine): graceful shutdown must wait for all events to be processed before Streams are stopped. Related to #162
1 parent ca70959 commit e013ae5

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

kstreams/streams.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
self.name = name or str(uuid.uuid4())
104104
self.deserializer = deserializer
105105
self.running = False
106+
self.is_processing = asyncio.Lock()
106107
self.initial_offsets = initial_offsets
107108
self.seeked_initial_offsets = False
108109
self.rebalance_listener = rebalance_listener
@@ -121,15 +122,18 @@ def _create_consumer(self) -> Consumer:
121122
return self.consumer_class(**config)
122123

123124
async def stop(self) -> None:
124-
if not self.running:
125-
return None
126-
127-
if self.consumer is not None:
128-
await self.consumer.stop()
125+
if self.running:
126+
# Don't run anymore to prevent new events coming
129127
self.running = False
130128

131-
if self._consumer_task is not None:
132-
self._consumer_task.cancel()
129+
async with self.is_processing:
130+
# Only enter this block when all the events have been
131+
# processed in the middleware chain
132+
if self.consumer is not None:
133+
await self.consumer.stop()
134+
135+
if self._consumer_task is not None:
136+
self._consumer_task.cancel()
133137

134138
async def _subscribe(self) -> None:
135139
# Always create a consumer on stream.start
@@ -141,7 +145,6 @@ async def _subscribe(self) -> None:
141145
self.consumer.subscribe(
142146
topics=self.topics, listener=self.rebalance_listener
143147
)
144-
self.running = True
145148

146149
async def commit(
147150
self, offsets: typing.Optional[typing.Dict[TopicPartition, int]] = None
@@ -206,6 +209,7 @@ async def start(self) -> None:
206209
return None
207210

208211
await self._subscribe()
212+
self.running = True
209213

210214
if self.udf_handler.type == UDFType.NO_TYPING:
211215
# normal use case
@@ -236,9 +240,10 @@ async def func_wrapper(self, func: typing.Awaitable) -> None:
236240
logger.exception(f"CRASHED Stream!!! Task {self._consumer_task} \n\n {e}")
237241

238242
async def func_wrapper_with_typing(self) -> None:
239-
while True:
243+
while self.running:
240244
cr = await self.getone()
241-
await self.func(cr)
245+
async with self.is_processing:
246+
await self.func(cr)
242247

243248
def seek_to_initial_offsets(self) -> None:
244249
if not self.seeked_initial_offsets and self.consumer is not None:

kstreams/test_utils/test_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ class TestMonitor(PrometheusMonitor):
1717
__test__ = False
1818

1919
def start(self, *args, **kwargs) -> None:
20-
print("herte....")
21-
# ...
20+
...
2221

2322
async def stop(self, *args, **kwargs) -> None:
2423
...

tests/test_stream_engine.py

+33
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,39 @@ async def stream(_):
264264
Consumer.stop.assert_awaited()
265265

266266

267+
@pytest.mark.asyncio
268+
async def test_wait_for_streams_before_stop(
269+
stream_engine: StreamEngine, consumer_record_factory: Callable[..., ConsumerRecord]
270+
):
271+
topic = "local--hello-kpn"
272+
value = b"Hello world"
273+
save_to_db = mock.AsyncMock()
274+
275+
async def getone(_):
276+
return consumer_record_factory(value=value)
277+
278+
@stream_engine.stream(topic)
279+
async def stream(cr: ConsumerRecord):
280+
# Use 5 seconds sleep to simulate a super slow event processing
281+
await asyncio.sleep(5)
282+
await save_to_db(cr.value)
283+
284+
with mock.patch.multiple(
285+
Consumer,
286+
start=mock.DEFAULT,
287+
stop=mock.DEFAULT,
288+
getone=getone,
289+
), mock.patch.multiple(Producer, start=mock.DEFAULT, stop=mock.DEFAULT):
290+
await stream_engine.start()
291+
await asyncio.sleep(0) # Allow stream coroutine to run once
292+
293+
# stop engine immediately, this should not break the streams
294+
# and it should wait until the event is processed.
295+
await stream_engine.stop()
296+
Consumer.stop.assert_awaited()
297+
save_to_db.assert_awaited_once_with(value)
298+
299+
267300
@pytest.mark.asyncio
268301
async def test_recreate_consumer_on_re_start_stream(stream_engine: StreamEngine):
269302
with mock.patch("kstreams.clients.aiokafka.AIOKafkaConsumer.start"):

0 commit comments

Comments
 (0)