From e6fab0d92861d0950b23ac93532bbdf35dc00b20 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Thu, 12 Dec 2024 07:37:23 +0000 Subject: [PATCH 1/2] feat: optimize SmartFuture --- Makefile | 2 +- a_sync/_smart.pyx | 269 +++++++++++++++++++++++++++++++++------------- 2 files changed, 198 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index 0e2a1d64..ea5bee19 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ docs: sphinx-apidoc --private -o ./docs/source ./a_sync cython: - python csetup.py build_ext --inplace + python setup.py build_ext --inplace stubs: stubgen ./a_sync -o . --include-docstrings \ No newline at end of file diff --git a/a_sync/_smart.pyx b/a_sync/_smart.pyx index edac1ece..99e7c894 100644 --- a/a_sync/_smart.pyx +++ b/a_sync/_smart.pyx @@ -9,6 +9,7 @@ import asyncio import logging import warnings import weakref +from libc.stdint cimport uintptr_t import a_sync.asyncio from a_sync._typing import * @@ -57,47 +58,6 @@ class _SmartFutureMixin(Generic[T]): _key: _Key _waiters: "weakref.WeakSet[SmartTask[T]]" - def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]: - """ - Await the smart future or task, handling waiters and logging. - - Yields: - The result of the future or task. - - Raises: - RuntimeError: If await wasn't used with future. - - Example: - Awaiting a SmartFuture: - - ```python - future = SmartFuture() - result = await future - ``` - - Awaiting a SmartTask: - - ```python - task = SmartTask(coro=my_coroutine()) - result = await task - ``` - """ - if _is_done(self): - return _get_result(self) # May raise too. - - self._asyncio_future_blocking = True - if current_task := asyncio.current_task(self._loop): - self._waiters.add(current_task) - current_task.add_done_callback( - self._waiter_done_cleanup_callback # type: ignore [union-attr] - ) - - logger.debug("awaiting %s", self) - yield self # This tells Task to wait for completion. - if _is_not_done(self): - raise RuntimeError("await wasn't used with future") - return _get_result(self) # May raise too. - @property def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> Py_ssize_t: """ @@ -117,35 +77,9 @@ class _SmartFutureMixin(Generic[T]): """ return count_waiters(self) - def _waiter_done_cleanup_callback( - self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask" - ) -> None: - """ - Callback to clean up waiters when a waiter task is done. - - Removes the waiter from _waiters, and _queue._futs if applicable. - - Args: - waiter: The waiter task to clean up. - - Example: - Automatically called when a waiter task completes. - """ - if _is_not_done(self): - self._waiters.remove(waiter) - - def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None: - """ - Callback to clean up waiters and remove the future from the queue when done. - - This method clears all waiters and removes the future from the associated queue. - """ - self._waiters.clear() - if queue := self._queue: - queue._futs.pop(self._key) - cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]): + cdef WeakSet waiters if _is_done(fut): return ZERO try: @@ -153,11 +87,55 @@ cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]): except AttributeError: return ONE cdef Py_ssize_t count = 0 - for waiter in waiters: + for waiter in waiters.iter(): count += count_waiters(waiter) return count +cdef class WeakSet: + _refs: dict[uintptr_t, object] + """Mapping from object ID to weak reference.""" + + def __cinit__(self): + self._refs = {} + + def _gc_callback(self, fut: asyncio.Future) -> None: + # Callback when a weakly-referenced object is garbage collected + self._refs.pop(id(fut), None) # Safely remove the item if it exists + + cdef void add(self, fut: asyncio.Future): + # Keep a weak reference with a callback for when the item is collected + ref = weakref.ref(fut, self._gc_callback) + self._refs[id(fut)] = ref + + cdef void remove(self, fut: asyncio.Future): + # Keep a weak reference with a callback for when the item is collected + try: + self._refs.pop(id(fut)) + except KeyError: + raise KeyError(fut) from None + + def __len__(self) -> int: + return len(self._refs) + + def __bool__(self) -> bool: + return bool(self._refs) + + def __contains__(self, item: asyncio.Future) -> bool: + ref = self._refs.get(id(item)) + return ref is not None and ref() is item + + def __iter__(self): + for ref in self._refs.values(): + item = ref() + if item is not None: + yield item + + def __repr__(self): + # Use list comprehension syntax within the repr function for clarity + return f"WeakSet({', '.join(repr(item) for item in self)})" + + cdef inline bint _is_done(fut: asyncio.Future): """Return True if the future is done. @@ -260,10 +238,10 @@ class SmartFuture(_SmartFutureMixin[T], asyncio.Future): super().__init__(loop=loop) if queue: self._queue = weakref.proxy(queue) + self.add_done_callback(SmartFuture._self_done_cleanup_callback) if key: self._key = key - self._waiters = weakref.WeakSet() - self.add_done_callback(SmartFuture._self_done_cleanup_callback) + self._waiters = WeakSet() def __repr__(self): return f"<{type(self).__name__} key={self._key} waiters={count_waiters(self)} {self._state}>" @@ -288,6 +266,73 @@ class SmartFuture(_SmartFutureMixin[T], asyncio.Future): """ return count_waiters(self) > count_waiters(other) + def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]: + """ + Await the smart future or task, handling waiters and logging. + + Yields: + The result of the future or task. + + Raises: + RuntimeError: If await wasn't used with future. + + Example: + Awaiting a SmartFuture: + + ```python + future = SmartFuture() + result = await future + ``` + + Awaiting a SmartTask: + + ```python + task = SmartTask(coro=my_coroutine()) + result = await task + ``` + """ + if _is_done(self): + return _get_result(self) # May raise too. + + self._asyncio_future_blocking = True + if current_task := asyncio.current_task(self._loop): + (self._waiters).add(current_task) + current_task.add_done_callback( + self._waiter_done_cleanup_callback # type: ignore [union-attr] + ) + + logger.debug("awaiting %s", self) + yield self # This tells Task to wait for completion. + if _is_not_done(self): + raise RuntimeError("await wasn't used with future") + return _get_result(self) # May raise too. + + def _waiter_done_cleanup_callback( + self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask" + ) -> None: + """ + Callback to clean up waiters when a waiter task is done. + + Removes the waiter from _waiters, and _queue._futs if applicable. + + Args: + waiter: The waiter task to clean up. + + Example: + Automatically called when a waiter task completes. + """ + if _is_not_done(self): + (self._waiters).remove(waiter) + + def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None: + """ + Callback to clean up waiters and remove the future from the queue when done. + + This method clears all waiters and removes the future from the associated queue. + """ + if queue := self._queue: + queue._futs.pop(self._key) + def create_future( *, @@ -366,6 +411,74 @@ class SmartTask(_SmartFutureMixin[T], asyncio.Task): self._waiters: Set["asyncio.Task[T]"] = set() self.add_done_callback(SmartTask._self_done_cleanup_callback) + def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]: + """ + Await the smart future or task, handling waiters and logging. + + Yields: + The result of the future or task. + + Raises: + RuntimeError: If await wasn't used with future. + + Example: + Awaiting a SmartFuture: + + ```python + future = SmartFuture() + result = await future + ``` + + Awaiting a SmartTask: + + ```python + task = SmartTask(coro=my_coroutine()) + result = await task + ``` + """ + if _is_done(self): + return _get_result(self) # May raise too. + + self._asyncio_future_blocking = True + if current_task := asyncio.current_task(self._loop): + (self._waiters).add(current_task) + current_task.add_done_callback( + self._waiter_done_cleanup_callback # type: ignore [union-attr] + ) + + logger.debug("awaiting %s", self) + yield self # This tells Task to wait for completion. + if _is_not_done(self): + raise RuntimeError("await wasn't used with future") + return _get_result(self) # May raise too. + + def _waiter_done_cleanup_callback( + self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask" + ) -> None: + """ + Callback to clean up waiters when a waiter task is done. + + Removes the waiter from _waiters, and _queue._futs if applicable. + + Args: + waiter: The waiter task to clean up. + + Example: + Automatically called when a waiter task completes. + """ + if _is_not_done(self): + (self._waiters).remove(waiter) + + def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None: + """ + Callback to clean up waiters and remove the future from the queue when done. + + This method clears all waiters and removes the future from the associated queue. + """ + (self._waiters).clear() + if queue := self._queue: + queue._futs.pop(self._key) + def smart_task_factory(loop: asyncio.AbstractEventLoop, coro: Awaitable[T]) -> SmartTask[T]: """ @@ -494,7 +607,7 @@ def shield( if exc is not None: outer.set_exception(exc) else: - outer.set_result(_get_result(inner)) + _set_result(outer, inner) def _outer_done_callback(outer): if _is_not_done(inner): @@ -504,6 +617,18 @@ def shield( outer.add_done_callback(_outer_done_callback) return outer +cdef void _set_result(outer: asyncio.Future, inner: asyncio.Future): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if outer._state != "PENDING": + raise asyncio.exceptions.InvalidStateError(f'{outer._state}: {outer!r}') + outer._result = _get_result(inner) + outer._state = "FINISHED" + outer._Future__schedule_callbacks() + __all__ = [ "create_future", From 11498ba961b97326e77a18a2712cdbc91f1e1e30 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Thu, 12 Dec 2024 07:43:15 +0000 Subject: [PATCH 2/2] fix: _set_result --- a_sync/_smart.pyx | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/a_sync/_smart.pyx b/a_sync/_smart.pyx index 99e7c894..9d6b23f5 100644 --- a/a_sync/_smart.pyx +++ b/a_sync/_smart.pyx @@ -607,7 +607,7 @@ def shield( if exc is not None: outer.set_exception(exc) else: - _set_result(outer, inner) + outer.set_result(_get_result(inner)) def _outer_done_callback(outer): if _is_not_done(inner): @@ -617,18 +617,6 @@ def shield( outer.add_done_callback(_outer_done_callback) return outer -cdef void _set_result(outer: asyncio.Future, inner: asyncio.Future): - """Mark the future done and set its result. - - If the future is already done when this method is called, raises - InvalidStateError. - """ - if outer._state != "PENDING": - raise asyncio.exceptions.InvalidStateError(f'{outer._state}: {outer!r}') - outer._result = _get_result(inner) - outer._state = "FINISHED" - outer._Future__schedule_callbacks() - __all__ = [ "create_future",