From c32b4fd953794a28b96fc20059c72701137ffd57 Mon Sep 17 00:00:00 2001
From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com>
Date: Thu, 12 Dec 2024 03:54:13 -0400
Subject: [PATCH] feat: optimize asyncio.Future.exception (#469)

---
 a_sync/_smart.pyx              | 10 +++----
 a_sync/a_sync/function.pyx     |  9 +++---
 a_sync/asyncio/create_task.pyx | 50 +++++++++++++++++++++++-----------
 3 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/a_sync/_smart.pyx b/a_sync/_smart.pyx
index 9d6b23f5..a782d5e8 100644
--- a/a_sync/_smart.pyx
+++ b/a_sync/_smart.pyx
@@ -152,7 +152,7 @@ cdef inline bint _is_not_done(fut: asyncio.Future):
     """
     return <str>fut._state == "PENDING"
 
-cdef inline bint cancelled(fut: asyncio.Future):
+cdef inline bint _is_cancelled(fut: asyncio.Future):
     """Return True if the future was cancelled."""
     return <str>fut._state == "CANCELLED"
 
@@ -173,7 +173,7 @@ cdef object _get_result(fut: asyncio.Future):
         raise fut._make_cancelled_error()
     raise asyncio.exceptions.InvalidStateError('Result is not ready.')
 
-def _get_exception(fut: asyncio.Future):
+cdef object _get_exception(fut: asyncio.Future):
     """Return the exception that was set on this future.
 
     The exception (or None if no exception was set) is returned only if
@@ -594,13 +594,13 @@ def shield(
         waiters.add(outer)
 
     def _inner_done_callback(inner):
-        if cancelled(outer):
-            if not cancelled(inner):
+        if _is_cancelled(outer):
+            if not _is_cancelled(inner):
                 # Mark inner's result as retrieved.
                 _get_exception(inner)
             return
 
-        if cancelled(inner):
+        if _is_cancelled(inner):
             outer.cancel()
         else:
             exc = _get_exception(inner)
diff --git a/a_sync/a_sync/function.pyx b/a_sync/a_sync/function.pyx
index 3360db9b..8d62d434 100644
--- a/a_sync/a_sync/function.pyx
+++ b/a_sync/a_sync/function.pyx
@@ -2,6 +2,7 @@ import functools
 import inspect
 import logging
 import sys
+from libc.stdint cimport uintptr_t
 
 from async_lru import _LRUCacheWrapper
 from async_property.base import AsyncPropertyDescriptor  # type: ignore [import]
@@ -120,10 +121,10 @@ cpdef void _validate_wrapped_fn(fn: Callable):
 
 cdef object _function_type = type(logging.getLogger)
 
-cdef set[Py_ssize_t] _argspec_validated = set()
+cdef set[uintptr_t] _argspec_validated = set()
 
 cdef void _validate_argspec_cached(fn: Callable):
-    cdef Py_ssize_t fid = id(fn)
+    cdef uintptr_t fid = id(fn)
     if fid not in _argspec_validated:
         _validate_argspec(fn)
         _argspec_validated.add(fid)
@@ -971,10 +972,10 @@ class ASyncDecorator(_ModifiedMixin):
             return ASyncFunctionSyncDefault(func, **self.modifiers)
 
 
-cdef set[Py_ssize_t] _is_genfunc_cache = set()
+cdef set[uintptr_t] _is_genfunc_cache = set()
 
 cdef void _check_not_genfunc_cached(func: Callable):
-    cdef Py_ssize_t fid = id(func)
+    cdef uintptr_t fid = id(func)
     if fid not in _is_genfunc_cache:
         _check_not_genfunc(func)
         _is_genfunc_cache.add(fid)
diff --git a/a_sync/asyncio/create_task.pyx b/a_sync/asyncio/create_task.pyx
index f7af0a31..1cc6989a 100644
--- a/a_sync/asyncio/create_task.pyx
+++ b/a_sync/asyncio/create_task.pyx
@@ -134,28 +134,46 @@ cdef void __prune_persisted_tasks():
     cdef object task
     cdef dict context
     for task in tuple(__persisted_tasks):
-        if _is_done(task) and (e := task.exception()):
-            # force exceptions related to this lib to bubble up
-            if not isinstance(e, exceptions.PersistedTaskException):
-                c_logger.exception(e)
-                raise e
-            # we have to manually log the traceback that asyncio would usually log
-            # since we already got the exception from the task and the usual handler will now not run
-            context = {
-                "message": f"{task.__class__.__name__} exception was never retrieved",
-                "exception": e,
-                "future": task,
-            }
-            if task._source_traceback:
-                context["source_traceback"] = task._source_traceback
-            task._loop.call_exception_handler(context)
-            __persisted_tasks.discard(task)
+        if _is_done(task):
+            if e := _get_exception(task):
+                # force exceptions related to this lib to bubble up
+                if not isinstance(e, exceptions.PersistedTaskException):
+                    c_logger.exception(e)
+                    raise e
+                # we have to manually log the traceback that asyncio would usually log
+                # since we already got the exception from the task and the usual handler will now not run
+                context = {
+                    "message": f"{task.__class__.__name__} exception was never retrieved",
+                    "exception": e,
+                    "future": task,
+                }
+                if task._source_traceback:
+                    context["source_traceback"] = task._source_traceback
+                task._loop.call_exception_handler(context)
+                __persisted_tasks.discard(task)
 
 
 cdef inline bint _is_done(fut: asyncio.Future):
     return <str>fut._state != "PENDING"
 
 
+cdef object _get_exception(fut: asyncio.Future):
+    """Return the exception that was set on this future.
+
+    The exception (or None if no exception was set) is returned only if
+    the future is done.  If the future has been cancelled, raises
+    CancelledError.  If the future isn't done yet, raises
+    InvalidStateError.
+    """
+    cdef str state = fut._state
+    if state == "FINISHED":
+        fut._Future__log_traceback = False
+        return fut._exception
+    if state == "CANCELLED":
+        raise fut._make_cancelled_error()
+    raise asyncio.exceptions.InvalidStateError('Exception is not set.')
+
+
 async def __persisted_task_exc_wrap(task: "asyncio.Task[T]") -> T:
     """
     Wrap a task to handle its exception in a specialized manner.