diff --git a/libs/langgraph/langgraph/pregel/executor.py b/libs/langgraph/langgraph/pregel/executor.py index 691098b7a..488f12661 100644 --- a/libs/langgraph/langgraph/pregel/executor.py +++ b/libs/langgraph/langgraph/pregel/executor.py @@ -86,19 +86,21 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: + # copy the tasks as done() callback may modify the dict + tasks = self.tasks.copy() # cancel all tasks that should be cancelled - for task, (cancel, _) in self.tasks.items(): + for task, (cancel, _) in tasks.items(): if cancel: task.cancel() # wait for all tasks to finish - if tasks := {t for t in self.tasks if not t.done()}: - concurrent.futures.wait(tasks) + if pending := {t for t in tasks if not t.done()}: + concurrent.futures.wait(pending) # shutdown the executor self.stack.__exit__(exc_type, exc_value, traceback) # re-raise the first exception that occurred in a task if exc_type is None: # if there's already an exception being raised, don't raise another one - for task, (_, reraise) in self.tasks.items(): + for task, (_, reraise) in tasks.items(): if not reraise: continue try: @@ -161,17 +163,19 @@ async def __aexit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: + # copy the tasks as done() callback may modify the dict + tasks = self.tasks.copy() # cancel all tasks that should be cancelled - for task, (cancel, _) in self.tasks.items(): + for task, (cancel, _) in tasks.items(): if cancel: task.cancel(self.sentinel) # wait for all tasks to finish - if self.tasks: - await asyncio.wait(self.tasks) + if tasks: + await asyncio.wait(tasks) # if there's already an exception being raised, don't raise another one if exc_type is None: # re-raise the first exception that occurred in a task - for task, (_, reraise) in self.tasks.items(): + for task, (_, reraise) in tasks.items(): if not reraise: continue try: