diff --git a/a_sync/_smart.py b/a_sync/_smart.pyx similarity index 79% rename from a_sync/_smart.py rename to a_sync/_smart.pyx index 84d6bad2..edac1ece 100644 --- a/a_sync/_smart.py +++ b/a_sync/_smart.pyx @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) +cdef Py_ssize_t ZERO = 0 +cdef Py_ssize_t ONE = 1 class _SmartFutureMixin(Generic[T]): """ @@ -80,8 +82,8 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T result = await task ``` """ - if self._state != "PENDING": - return self.result() # May raise too. + if _is_done(self): + return _get_result(self) # May raise too. self._asyncio_future_blocking = True if current_task := asyncio.current_task(self._loop): @@ -92,12 +94,12 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T logger.debug("awaiting %s", self) yield self # This tells Task to wait for completion. - if self._state == "PENDING": + if _is_not_done(self): raise RuntimeError("await wasn't used with future") - return self.result() # May raise too. + return _get_result(self) # May raise too. @property - def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int: + def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> Py_ssize_t: """ Get the number of waiters currently awaiting the future or task. @@ -113,9 +115,7 @@ def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int: See Also: - :meth:`_waiter_done_cleanup_callback` """ - if self._state != "PENDING": - return 0 - return sum(getattr(waiter, "num_waiters", 1) for waiter in self._waiters) + return count_waiters(self) def _waiter_done_cleanup_callback( self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask" @@ -131,7 +131,7 @@ def _waiter_done_cleanup_callback( Example: Automatically called when a waiter task completes. """ - if self._state == "PENDING": + if _is_not_done(self): self._waiters.remove(waiter) def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None: @@ -145,6 +145,72 @@ def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None queue._futs.pop(self._key) +cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]): + if _is_done(fut): + return ZERO + try: + waiters = fut._waiters + except AttributeError: + return ONE + cdef Py_ssize_t count = 0 + for waiter in waiters: + count += count_waiters(waiter) + return count + + +cdef inline bint _is_done(fut: asyncio.Future): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return fut._state != "PENDING" + +cdef inline bint _is_not_done(fut: asyncio.Future): + """Return False if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return fut._state == "PENDING" + +cdef inline bint cancelled(fut: asyncio.Future): + """Return True if the future was cancelled.""" + return fut._state == "CANCELLED" + +cdef object _get_result(fut: asyncio.Future): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + cdef str state = fut._state + if state == "FINISHED": + fut._Future__log_traceback = False + if fut._exception is not None: + raise fut._exception.with_traceback(fut._exception_tb) + return fut._result + if state == "CANCELLED": + raise fut._make_cancelled_error() + raise asyncio.exceptions.InvalidStateError('Result is not ready.') + +def _get_exception(fut: asyncio.Future): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + cdef str state = fut._state + if state == "FINISHED": + fut._Future__log_traceback = False + return fut._exception + if state == "CANCELLED": + raise fut._make_cancelled_error() + raise asyncio.exceptions.InvalidStateError('Exception is not set.') + class SmartFuture(_SmartFutureMixin[T], asyncio.Future): """ A smart future that tracks waiters and integrates with a smart processing queue. @@ -200,9 +266,9 @@ def __init__( self.add_done_callback(SmartFuture._self_done_cleanup_callback) def __repr__(self): - return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>" + return f"<{type(self).__name__} key={self._key} waiters={count_waiters(self)} {self._state}>" - def __lt__(self, other: "SmartFuture[T]") -> bool: + def __lt__(self, other: "SmartFuture[T]") -> bint: """ Compare the number of waiters to determine priority in a heap. Lower values indicate higher priority, so more waiters means 'less than'. @@ -220,7 +286,7 @@ def __lt__(self, other: "SmartFuture[T]") -> bool: See Also: - :meth:`num_waiters` """ - return self.num_waiters > other.num_waiters + return count_waiters(self) > count_waiters(other) def create_future( @@ -297,7 +363,7 @@ def __init__( - :func:`asyncio.create_task` """ asyncio.Task.__init__(self, coro, loop=loop, name=name) - self._waiters: Set["asyncio.Task[T]"] = set() + self._waiters: Set["asyncio.Task[T]"] = set() self.add_done_callback(SmartTask._self_done_cleanup_callback) @@ -405,7 +471,7 @@ def shield( stacklevel=2, ) inner = asyncio.ensure_future(arg, loop=loop) - if inner._state != "PENDING": + if _is_done(inner): # Shortcut. return inner loop = asyncio.futures._get_loop(inner) @@ -415,23 +481,23 @@ def shield( waiters.add(outer) def _inner_done_callback(inner): - if outer.cancelled(): - if not inner.cancelled(): + if cancelled(outer): + if not cancelled(inner): # Mark inner's result as retrieved. - inner.exception() + _get_exception(inner) return - if inner.cancelled(): + if cancelled(inner): outer.cancel() else: - exc = inner.exception() + exc = _get_exception(inner) if exc is not None: outer.set_exception(exc) else: - outer.set_result(inner.result()) + outer.set_result(_get_result(inner)) def _outer_done_callback(outer): - if inner._state == "PENDING": + if _is_not_done(inner): inner.remove_done_callback(_inner_done_callback) inner.add_done_callback(_inner_done_callback)