From 36bcb63eadd57589f91f31fbb21f794334ffd1cd Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Tue, 23 Apr 2024 23:25:23 +0000 Subject: [PATCH] chore: refactor --- a_sync/_smart.py | 42 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/a_sync/_smart.py b/a_sync/_smart.py index 970461c0..6d0cd7a8 100644 --- a/a_sync/_smart.py +++ b/a_sync/_smart.py @@ -16,6 +16,21 @@ logger = logging.getLogger(__name__) class _SmartFutureMixin(Generic[T]): + _queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None + def __await__(self): + logger.debug("entering %s", self) + if self.done(): + return self.result() # May raise too. + self._asyncio_future_blocking = True + self._waiters.add(current_task := asyncio.current_task(self._loop)) + logger.debug("awaiting %s", self) + yield self # This tells Task to wait for completion. + self._waiters.remove(current_task) + if self._queue and not self._waiters: + self._queue._futs.pop(self._key) + if not self.done(): + raise RuntimeError("await wasn't used with future") + return self.result() # May raise too. @property def num_waiters(self) -> int: return sum(getattr(waiter, 'num_waiters', 1) for waiter in self._waiters) @@ -30,37 +45,12 @@ def __init__(self, queue: "SmartProcessingQueue", key: _Key, *, loop: Optional[a self._key = key def __repr__(self): return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>" - def __await__(self): - logger.debug("entering %s", self) - if self.done(): - return self.result() # May raise too. - self._asyncio_future_blocking = True - self._waiters.add(current_task := asyncio.current_task(self._loop)) - logger.debug("awaiting %s", self) - yield self # This tells Task to wait for completion. - self._waiters.remove(current_task) - if self.num_waiters == 0: - self._queue._futs.pop(self._key) - if not self.done(): - raise RuntimeError("await wasn't used with future") - return self.result() # May raise too. def __lt__(self, other: "SmartFuture") -> bool: """heap considers lower values as higher priority so a future with more waiters will be 'less than' a future with less waiters.""" return self.num_waiters > other.num_waiters class SmartTask(_SmartFutureMixin[T], asyncio.Task): - def __await__(self): - logger.debug("entering %s", self) - if self.done(): - return self.result() # May raise too. - self._asyncio_future_blocking = True - self._waiters.add(current_task := asyncio.current_task(self._loop)) - logger.debug("awaiting %s", self) - yield self # This tells Task to wait for completion. - self._waiters.remove(current_task) - if not self.done(): - raise RuntimeError("await wasn't used with future") - return self.result() # May raise too. + ... def smart_task_factory(loop: asyncio.AbstractEventLoop, coro: Awaitable[T]) -> SmartTask[T]: return SmartTask(coro, loop=loop)