diff --git a/src/prefect/_internal/concurrency/calls.py b/src/prefect/_internal/concurrency/calls.py index 13e9e9bbfc61..6ca4dc45567e 100644 --- a/src/prefect/_internal/concurrency/calls.py +++ b/src/prefect/_internal/concurrency/calls.py @@ -10,6 +10,7 @@ import dataclasses import inspect import threading +import weakref from concurrent.futures._base import ( CANCELLED, CANCELLED_AND_NOTIFIED, @@ -32,21 +33,30 @@ T = TypeVar("T") P = ParamSpec("P") - -# Tracks the current call being executed -current_call: contextvars.ContextVar["Call"] = contextvars.ContextVar("current_call") +# Tracks the current call being executed. Note that storing the `Call` +# object for an async call directly in the contextvar appears to create a +# memory leak, despite the fact that we `reset` when leaving the context +# that sets this contextvar. A weakref avoids the leak and works because a) +# we already have strong references to the `Call` objects in other places +# and b) this is used for performance optimizations where we have fallback +# behavior if this weakref is garbage collected. A fix for issue #10952. +current_call: contextvars.ContextVar["weakref.ref[Call]"] = contextvars.ContextVar( + "current_call" +) # Create a strong reference to tasks to prevent destruction during execution errors _ASYNC_TASK_REFS = set() def get_current_call() -> Optional["Call"]: - return current_call.get(None) + call_ref = current_call.get(None) + if call_ref: + return call_ref() @contextlib.contextmanager def set_current_call(call: "Call"): - token = current_call.set(call) + token = current_call.set(weakref.ref(call)) try: yield finally: @@ -181,6 +191,29 @@ def result(self, timeout=None): # Break a reference cycle with the exception in self._exception self = None + def _invoke_callbacks(self): + """ + Invoke our done callbacks and clean up cancel scopes and cancel + callbacks. Fixes a memory leak that hung on to Call objects, + preventing garbage collection of Futures. + + A fix for #10952. + """ + if self._done_callbacks: + done_callbacks = self._done_callbacks[:] + self._done_callbacks[:] = [] + + for callback in done_callbacks: + try: + callback(self) + except Exception: + logger.exception("exception calling callback for %r", self) + + self._cancel_callbacks = [] + if self._cancel_scope: + self._cancel_scope._callbacks = [] + self._cancel_scope = None + @dataclasses.dataclass class Call(Generic[T]): diff --git a/src/prefect/_internal/concurrency/cancellation.py b/src/prefect/_internal/concurrency/cancellation.py index 8dcb4e6519de..25c2c2b5ad9b 100644 --- a/src/prefect/_internal/concurrency/cancellation.py +++ b/src/prefect/_internal/concurrency/cancellation.py @@ -270,6 +270,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Mark as cancelled self.cancel(throw=False) + # TODO: Can we also delete the scope? + # We have to exit this scope to prevent leaking memory. A fix for + # issue #10952. + self._anyio_scope.__exit__(exc_type, exc_val, exc_tb) + super().__exit__(exc_type, exc_val, exc_tb) if self.cancelled() and exc_type is not CancelledError: