Skip to content

Commit

Permalink
feat: cythonize _smart.py (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Dec 12, 2024
1 parent 372451b commit 421929e
Showing 1 changed file with 87 additions and 21 deletions.
108 changes: 87 additions & 21 deletions a_sync/_smart.py → a_sync/_smart.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

logger = logging.getLogger(__name__)

cdef Py_ssize_t ZERO = 0
cdef Py_ssize_t ONE = 1

class _SmartFutureMixin(Generic[T]):
"""
Expand Down Expand Up @@ -80,8 +82,8 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T
result = await task
```
"""
if self._state != "PENDING":
return self.result() # May raise too.
if _is_done(self):
return _get_result(self) # May raise too.

self._asyncio_future_blocking = True
if current_task := asyncio.current_task(self._loop):
Expand All @@ -92,12 +94,12 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T

logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
if self._state == "PENDING":
if _is_not_done(self):
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.
return _get_result(self) # May raise too.

@property
def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int:
def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> Py_ssize_t:
"""
Get the number of waiters currently awaiting the future or task.

Expand All @@ -113,9 +115,7 @@ def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int:
See Also:
- :meth:`_waiter_done_cleanup_callback`
"""
if self._state != "PENDING":
return 0
return sum(getattr(waiter, "num_waiters", 1) for waiter in self._waiters)
return count_waiters(self)

def _waiter_done_cleanup_callback(
self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask"
Expand All @@ -131,7 +131,7 @@ def _waiter_done_cleanup_callback(
Example:
Automatically called when a waiter task completes.
"""
if self._state == "PENDING":
if _is_not_done(self):
self._waiters.remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
Expand All @@ -145,6 +145,72 @@ def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None
queue._futs.pop(self._key)


cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]):
if _is_done(fut):
return ZERO
try:
waiters = fut._waiters
except AttributeError:
return ONE
cdef Py_ssize_t count = 0
for waiter in waiters:
count += count_waiters(waiter)
return count


cdef inline bint _is_done(fut: asyncio.Future):
"""Return True if the future is done.
Done means either that a result / exception are available, or that the
future was cancelled.
"""
return <str>fut._state != "PENDING"

cdef inline bint _is_not_done(fut: asyncio.Future):
"""Return False if the future is done.
Done means either that a result / exception are available, or that the
future was cancelled.
"""
return <str>fut._state == "PENDING"

cdef inline bint cancelled(fut: asyncio.Future):
"""Return True if the future was cancelled."""
return <str>fut._state == "CANCELLED"

cdef object _get_result(fut: asyncio.Future):
"""Return the result this future represents.
If the future has been cancelled, raises CancelledError. If the
future's result isn't yet available, raises InvalidStateError. If
the future is done and has an exception set, this exception is raised.
"""
cdef str state = fut._state
if state == "FINISHED":
fut._Future__log_traceback = False
if fut._exception is not None:
raise fut._exception.with_traceback(fut._exception_tb)
return fut._result
if state == "CANCELLED":
raise fut._make_cancelled_error()
raise asyncio.exceptions.InvalidStateError('Result is not ready.')

def _get_exception(fut: asyncio.Future):
"""Return the exception that was set on this future.
The exception (or None if no exception was set) is returned only if
the future is done. If the future has been cancelled, raises
CancelledError. If the future isn't done yet, raises
InvalidStateError.
"""
cdef str state = fut._state
if state == "FINISHED":
fut._Future__log_traceback = False
return fut._exception
if state == "CANCELLED":
raise fut._make_cancelled_error()
raise asyncio.exceptions.InvalidStateError('Exception is not set.')

class SmartFuture(_SmartFutureMixin[T], asyncio.Future):
"""
A smart future that tracks waiters and integrates with a smart processing queue.
Expand Down Expand Up @@ -200,9 +266,9 @@ def __init__(
self.add_done_callback(SmartFuture._self_done_cleanup_callback)

def __repr__(self):
return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>"
return f"<{<str>type(self).__name__} key={self._key} waiters={count_waiters(self)} {<str>self._state}>"

def __lt__(self, other: "SmartFuture[T]") -> bool:
def __lt__(self, other: "SmartFuture[T]") -> bint:
"""
Compare the number of waiters to determine priority in a heap.
Lower values indicate higher priority, so more waiters means 'less than'.
Expand All @@ -220,7 +286,7 @@ def __lt__(self, other: "SmartFuture[T]") -> bool:
See Also:
- :meth:`num_waiters`
"""
return self.num_waiters > other.num_waiters
return count_waiters(self) > count_waiters(other)


def create_future(
Expand Down Expand Up @@ -297,7 +363,7 @@ def __init__(
- :func:`asyncio.create_task`
"""
asyncio.Task.__init__(self, coro, loop=loop, name=name)
self._waiters: Set["asyncio.Task[T]"] = set()
self._waiters: Set["asyncio.Task[T]"] = <set>set()
self.add_done_callback(SmartTask._self_done_cleanup_callback)


Expand Down Expand Up @@ -405,7 +471,7 @@ def shield(
stacklevel=2,
)
inner = asyncio.ensure_future(arg, loop=loop)
if inner._state != "PENDING":
if _is_done(inner):
# Shortcut.
return inner
loop = asyncio.futures._get_loop(inner)
Expand All @@ -415,23 +481,23 @@ def shield(
waiters.add(outer)

def _inner_done_callback(inner):
if outer.cancelled():
if not inner.cancelled():
if cancelled(outer):
if not cancelled(inner):
# Mark inner's result as retrieved.
inner.exception()
_get_exception(inner)
return

if inner.cancelled():
if cancelled(inner):
outer.cancel()
else:
exc = inner.exception()
exc = _get_exception(inner)
if exc is not None:
outer.set_exception(exc)
else:
outer.set_result(inner.result())
outer.set_result(_get_result(inner))

def _outer_done_callback(outer):
if inner._state == "PENDING":
if _is_not_done(inner):
inner.remove_done_callback(_inner_done_callback)

inner.add_done_callback(_inner_done_callback)
Expand Down

0 comments on commit 421929e

Please sign in to comment.