@@ -103,6 +103,7 @@ def __init__(
103
103
self .name = name or str (uuid .uuid4 ())
104
104
self .deserializer = deserializer
105
105
self .running = False
106
+ self .is_processing = asyncio .Lock ()
106
107
self .initial_offsets = initial_offsets
107
108
self .seeked_initial_offsets = False
108
109
self .rebalance_listener = rebalance_listener
@@ -121,15 +122,18 @@ def _create_consumer(self) -> Consumer:
121
122
return self .consumer_class (** config )
122
123
123
124
async def stop (self ) -> None :
124
- if not self .running :
125
- return None
126
-
127
- if self .consumer is not None :
128
- await self .consumer .stop ()
125
+ if self .running :
126
+ # Don't run anymore to prevent new events comming
129
127
self .running = False
130
128
131
- if self ._consumer_task is not None :
132
- self ._consumer_task .cancel ()
129
+ async with self .is_processing :
130
+ # Only enter this block when all the events have been
131
+ # proccessed in the middleware chain
132
+ if self .consumer is not None :
133
+ await self .consumer .stop ()
134
+
135
+ if self ._consumer_task is not None :
136
+ self ._consumer_task .cancel ()
133
137
134
138
async def _subscribe (self ) -> None :
135
139
# Always create a consumer on stream.start
@@ -141,7 +145,6 @@ async def _subscribe(self) -> None:
141
145
self .consumer .subscribe (
142
146
topics = self .topics , listener = self .rebalance_listener
143
147
)
144
- self .running = True
145
148
146
149
async def commit (
147
150
self , offsets : typing .Optional [typing .Dict [TopicPartition , int ]] = None
@@ -206,6 +209,7 @@ async def start(self) -> None:
206
209
return None
207
210
208
211
await self ._subscribe ()
212
+ self .running = True
209
213
210
214
if self .udf_handler .type == UDFType .NO_TYPING :
211
215
# normal use case
@@ -236,9 +240,10 @@ async def func_wrapper(self, func: typing.Awaitable) -> None:
236
240
logger .exception (f"CRASHED Stream!!! Task { self ._consumer_task } \n \n { e } " )
237
241
238
242
async def func_wrapper_with_typing (self ) -> None :
239
- while True :
243
+ while self . running :
240
244
cr = await self .getone ()
241
- await self .func (cr )
245
+ async with self .is_processing :
246
+ await self .func (cr )
242
247
243
248
def seek_to_initial_offsets (self ) -> None :
244
249
if not self .seeked_initial_offsets and self .consumer is not None :
0 commit comments