Skip to content

Commit

Permalink
feat: optimize SmartFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Dec 12, 2024
1 parent 421929e commit e6fab0d
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 73 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ docs:
sphinx-apidoc --private -o ./docs/source ./a_sync

cython:
python csetup.py build_ext --inplace
python setup.py build_ext --inplace

stubs:
stubgen ./a_sync -o . --include-docstrings
269 changes: 197 additions & 72 deletions a_sync/_smart.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import asyncio
import logging
import warnings
import weakref
from libc.stdint cimport uintptr_t

import a_sync.asyncio
from a_sync._typing import *
Expand Down Expand Up @@ -57,47 +58,6 @@ class _SmartFutureMixin(Generic[T]):
_key: _Key
_waiters: "weakref.WeakSet[SmartTask[T]]"

def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]:
"""
Await the smart future or task, handling waiters and logging.

Yields:
The result of the future or task.

Raises:
RuntimeError: If await wasn't used with future.

Example:
Awaiting a SmartFuture:

```python
future = SmartFuture()
result = await future
```

Awaiting a SmartTask:

```python
task = SmartTask(coro=my_coroutine())
result = await task
```
"""
if _is_done(self):
return _get_result(self) # May raise too.

self._asyncio_future_blocking = True
if current_task := asyncio.current_task(self._loop):
self._waiters.add(current_task)
current_task.add_done_callback(
self._waiter_done_cleanup_callback # type: ignore [union-attr]
)

logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
if _is_not_done(self):
raise RuntimeError("await wasn't used with future")
return _get_result(self) # May raise too.

@property
def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> Py_ssize_t:
"""
Expand All @@ -117,47 +77,65 @@ class _SmartFutureMixin(Generic[T]):
"""
return count_waiters(self)

def _waiter_done_cleanup_callback(
self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask"
) -> None:
"""
Callback to clean up waiters when a waiter task is done.

Removes the waiter from _waiters, and _queue._futs if applicable.

Args:
waiter: The waiter task to clean up.

Example:
Automatically called when a waiter task completes.
"""
if _is_not_done(self):
self._waiters.remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
"""
Callback to clean up waiters and remove the future from the queue when done.

This method clears all waiters and removes the future from the associated queue.
"""
self._waiters.clear()
if queue := self._queue:
queue._futs.pop(self._key)


cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]):
cdef WeakSet waiters
if _is_done(fut):
return ZERO
try:
waiters = fut._waiters
except AttributeError:
return ONE
cdef Py_ssize_t count = 0
for waiter in waiters:
for waiter in waiters.iter():
count += count_waiters(waiter)
return count


cdef class WeakSet:
_refs: dict[uintptr_t, object]
"""Mapping from object ID to weak reference."""

def __cinit__(self):
self._refs = {}

def _gc_callback(self, fut: asyncio.Future) -> None:
# Callback when a weakly-referenced object is garbage collected
self._refs.pop(<uintptr_t>id(fut), None) # Safely remove the item if it exists

cdef void add(self, fut: asyncio.Future):
# Keep a weak reference with a callback for when the item is collected
ref = weakref.ref(fut, self._gc_callback)
self._refs[<uintptr_t>id(fut)] = ref

cdef void remove(self, fut: asyncio.Future):
# Keep a weak reference with a callback for when the item is collected
try:
self._refs.pop(<uintptr_t>id(fut))
except KeyError:
raise KeyError(fut) from None

def __len__(self) -> int:
return len(self._refs)

def __bool__(self) -> bool:
return bool(self._refs)

def __contains__(self, item: asyncio.Future) -> bool:
ref = self._refs.get(<uintptr_t>id(item))
return ref is not None and ref() is item

def __iter__(self):
for ref in self._refs.values():
item = ref()
if item is not None:
yield item

def __repr__(self):
# Use list comprehension syntax within the repr function for clarity
return f"WeakSet({', '.join(repr(item) for item in self)})"


cdef inline bint _is_done(fut: asyncio.Future):
"""Return True if the future is done.
Expand Down Expand Up @@ -260,10 +238,10 @@ class SmartFuture(_SmartFutureMixin[T], asyncio.Future):
super().__init__(loop=loop)
if queue:
self._queue = weakref.proxy(queue)
self.add_done_callback(SmartFuture._self_done_cleanup_callback)
if key:
self._key = key
self._waiters = weakref.WeakSet()
self.add_done_callback(SmartFuture._self_done_cleanup_callback)
self._waiters = WeakSet()

def __repr__(self):
return f"<{<str>type(self).__name__} key={self._key} waiters={count_waiters(self)} {<str>self._state}>"
Expand All @@ -288,6 +266,73 @@ class SmartFuture(_SmartFutureMixin[T], asyncio.Future):
"""
return count_waiters(self) > count_waiters(other)

def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]:
"""
Await the smart future or task, handling waiters and logging.

Yields:
The result of the future or task.

Raises:
RuntimeError: If await wasn't used with future.

Example:
Awaiting a SmartFuture:

```python
future = SmartFuture()
result = await future
```

Awaiting a SmartTask:

```python
task = SmartTask(coro=my_coroutine())
result = await task
```
"""
if _is_done(self):
return _get_result(self) # May raise too.

self._asyncio_future_blocking = True
if current_task := asyncio.current_task(self._loop):
(<WeakSet>self._waiters).add(current_task)
current_task.add_done_callback(
self._waiter_done_cleanup_callback # type: ignore [union-attr]
)

logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
if _is_not_done(self):
raise RuntimeError("await wasn't used with future")
return _get_result(self) # May raise too.

def _waiter_done_cleanup_callback(
self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask"
) -> None:
"""
Callback to clean up waiters when a waiter task is done.

Removes the waiter from _waiters, and _queue._futs if applicable.

Args:
waiter: The waiter task to clean up.

Example:
Automatically called when a waiter task completes.
"""
if _is_not_done(self):
(<WeakSet>self._waiters).remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
"""
Callback to clean up waiters and remove the future from the queue when done.

This method clears all waiters and removes the future from the associated queue.
"""
if queue := self._queue:
queue._futs.pop(self._key)


def create_future(
*,
Expand Down Expand Up @@ -366,6 +411,74 @@ class SmartTask(_SmartFutureMixin[T], asyncio.Task):
self._waiters: Set["asyncio.Task[T]"] = <set>set()
self.add_done_callback(SmartTask._self_done_cleanup_callback)

def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]:
"""
Await the smart future or task, handling waiters and logging.

Yields:
The result of the future or task.

Raises:
RuntimeError: If await wasn't used with future.

Example:
Awaiting a SmartFuture:

```python
future = SmartFuture()
result = await future
```

Awaiting a SmartTask:

```python
task = SmartTask(coro=my_coroutine())
result = await task
```
"""
if _is_done(self):
return _get_result(self) # May raise too.

self._asyncio_future_blocking = True
if current_task := asyncio.current_task(self._loop):
(<set>self._waiters).add(current_task)
current_task.add_done_callback(
self._waiter_done_cleanup_callback # type: ignore [union-attr]
)

logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
if _is_not_done(self):
raise RuntimeError("await wasn't used with future")
return _get_result(self) # May raise too.

def _waiter_done_cleanup_callback(
self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask"
) -> None:
"""
Callback to clean up waiters when a waiter task is done.

Removes the waiter from _waiters, and _queue._futs if applicable.

Args:
waiter: The waiter task to clean up.

Example:
Automatically called when a waiter task completes.
"""
if _is_not_done(self):
(<set>self._waiters).remove(waiter)

def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None:
"""
Callback to clean up waiters and remove the future from the queue when done.

This method clears all waiters and removes the future from the associated queue.
"""
(<set>self._waiters).clear()
if queue := self._queue:
queue._futs.pop(self._key)


def smart_task_factory(loop: asyncio.AbstractEventLoop, coro: Awaitable[T]) -> SmartTask[T]:
"""
Expand Down Expand Up @@ -494,7 +607,7 @@ def shield(
if exc is not None:
outer.set_exception(exc)
else:
outer.set_result(_get_result(inner))
_set_result(outer, inner)

def _outer_done_callback(outer):
if _is_not_done(inner):
Expand All @@ -504,6 +617,18 @@ def shield(
outer.add_done_callback(_outer_done_callback)
return outer

cdef void _set_result(outer: asyncio.Future, inner: asyncio.Future):
"""Mark the future done and set its result.
If the future is already done when this method is called, raises
InvalidStateError.
"""
if <str>outer._state != "PENDING":
raise asyncio.exceptions.InvalidStateError(f'{outer._state}: {outer!r}')
outer._result = _get_result(inner)
outer._state = "FINISHED"
outer._Future__schedule_callbacks()


__all__ = [
"create_future",
Expand Down

0 comments on commit e6fab0d

Please sign in to comment.