@@ -181,6 +181,8 @@ def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
181
181
# Event loop needs to remain in the same process
182
182
self ._task_for_pid : Optional [int ] = None
183
183
self ._loop : Optional [asyncio .AbstractEventLoop ] = None
184
+ # Track active callback tasks so they have a strong reference and can be cancelled on kill
185
+ self ._active_tasks : set [asyncio .Task ] = set ()
184
186
185
187
@property
186
188
def is_alive (self ) -> bool :
@@ -195,6 +197,12 @@ def kill(self) -> None:
195
197
self ._task .cancel ()
196
198
self ._task = None
197
199
self ._task_for_pid = None
200
+ # Also cancel any active callback tasks
201
+ # Avoid modifying the set while cancelling tasks
202
+ tasks_to_cancel = set (self ._active_tasks )
203
+ for task in tasks_to_cancel :
204
+ task .cancel ()
205
+ self ._active_tasks .clear ()
198
206
self ._loop = None
199
207
200
208
def start (self ) -> None :
@@ -256,16 +264,30 @@ def submit(self, callback: Callable[[], Any]) -> bool:
256
264
async def _target (self ) -> None :
257
265
while True :
258
266
callback = await self ._queue .get ()
259
- try :
260
- if inspect .iscoroutinefunction (callback ):
261
- # Callback is an async coroutine, need to await it
262
- await callback ()
263
- else :
264
- # Callback is a sync function, need to call it
265
- callback ()
266
- except Exception :
267
- logger .error ("Failed processing job" , exc_info = True )
268
- finally :
269
- self ._queue .task_done ()
267
+ # Firing tasks instead of awaiting them allows for concurrent requests
268
+ task = asyncio .create_task (self ._process_callback (callback ))
269
+ # Create a strong reference to the task so it can be cancelled on kill
270
+ # and does not get garbage collected while running
271
+ self ._active_tasks .add (task )
272
+ task .add_done_callback (self ._on_task_complete )
270
273
# Yield to let the event loop run other tasks
271
274
await asyncio .sleep (0 )
275
+
276
+ async def _process_callback (self , callback : Callable [[], Any ]) -> None :
277
+ if inspect .iscoroutinefunction (callback ):
278
+ # Callback is an async coroutine, need to await it
279
+ await callback ()
280
+ else :
281
+ # Callback is a sync function, need to call it
282
+ callback ()
283
+
284
+ def _on_task_complete (self , task : asyncio .Task [None ]) -> None :
285
+ try :
286
+ task .result ()
287
+ except Exception :
288
+ logger .error ("Failed processing job" , exc_info = True )
289
+ finally :
290
+ # Mark the task as done and remove it from the active tasks set
291
+ # This happens only after the task has completed
292
+ self ._queue .task_done ()
293
+ self ._active_tasks .discard (task )
0 commit comments