Skip to content

Commit

Permalink
chore: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Apr 23, 2024
1 parent 221cbe2 commit 36bcb63
Showing 1 changed file with 16 additions and 26 deletions.
42 changes: 16 additions & 26 deletions a_sync/_smart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
logger = logging.getLogger(__name__)

class _SmartFutureMixin(Generic[T]):
_queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None
def __await__(self):
logger.debug("entering %s", self)
if self.done():
return self.result() # May raise too.
self._asyncio_future_blocking = True
self._waiters.add(current_task := asyncio.current_task(self._loop))
logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
self._waiters.remove(current_task)
if self._queue and not self._waiters:
self._queue._futs.pop(self._key)
if not self.done():
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.
@property
def num_waiters(self) -> int:
return sum(getattr(waiter, 'num_waiters', 1) for waiter in self._waiters)
Expand All @@ -30,37 +45,12 @@ def __init__(self, queue: "SmartProcessingQueue", key: _Key, *, loop: Optional[a
self._key = key
def __repr__(self):
return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>"
def __await__(self):
logger.debug("entering %s", self)
if self.done():
return self.result() # May raise too.
self._asyncio_future_blocking = True
self._waiters.add(current_task := asyncio.current_task(self._loop))
logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
self._waiters.remove(current_task)
if self.num_waiters == 0:
self._queue._futs.pop(self._key)
if not self.done():
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.
def __lt__(self, other: "SmartFuture") -> bool:
"""heap considers lower values as higher priority so a future with more waiters will be 'less than' a future with less waiters."""
return self.num_waiters > other.num_waiters

class SmartTask(_SmartFutureMixin[T], asyncio.Task):
def __await__(self):
logger.debug("entering %s", self)
if self.done():
return self.result() # May raise too.
self._asyncio_future_blocking = True
self._waiters.add(current_task := asyncio.current_task(self._loop))
logger.debug("awaiting %s", self)
yield self # This tells Task to wait for completion.
self._waiters.remove(current_task)
if not self.done():
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.
...

def smart_task_factory(loop: asyncio.AbstractEventLoop, coro: Awaitable[T]) -> SmartTask[T]:
return SmartTask(coro, loop=loop)
Expand Down

0 comments on commit 36bcb63

Please sign in to comment.