Skip to content

Commit

Permalink
feat: cache SmartProcessingQueue futures
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Apr 22, 2024
1 parent 5e796ec commit b6ae00d
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions a_sync/primitives/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
self._workers
if self._no_futs:
return await super().put((args, kwargs))
fut = asyncio.get_event_loop().create_future()
fut = self._create_future()
await super().put((args, kwargs, fut))
return fut
def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
Expand Down Expand Up @@ -165,8 +165,11 @@ def _validate_args(i: int, can_return_less: bool) -> None:


class SmartFuture(asyncio.Future, Generic[T]):
# classvar holds default value for instances
_waiters: Set["asyncio.Task[T]"] = set()
def __init__(self, queue: "SmartProcessingQueue", key: "_Key", *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(loop=loop)
self._queue = queue
self._key = key
self._waiters: Set["asyncio.Task[T]"] = set()
def __repr__(self):
return f"<{type(self).__name__} waiters={self.num_waiters} {self._state}>"
def __await__(self):
Expand All @@ -179,6 +182,8 @@ def __await__(self):
logger.info("%s waiters: %s", self, self._waiters)
yield self # This tells Task to wait for completion.
self._waiters.remove(current_task)
if self.num_waiters == 0:
self._queue._futs.pop(self._key)
if not self.done():
raise RuntimeError("await wasn't used with future")
return self.result() # May raise too.
Expand Down Expand Up @@ -220,12 +225,16 @@ def _get(self, heapify=heapq.heapify, heappop=heapq.heappop):
heapify(self._queue)
# take the job with the most waiters
return heappop(self._queue)
def _create_future(self) -> "asyncio.Future[V]":
return SmartFuture(loop=asyncio.get_event_loop())
def _get_key(self, *args, **kwargs) -> "_Key":
return (args, tuple((kwarg, kwargs[kwarg]) for kwarg in sorted(kwargs)))
def _create_future(self, key: "_Key") -> "asyncio.Future[V]":
return SmartFuture(key, loop=asyncio.get_event_loop())

class VariablePriorityQueue(_VariablePriorityQueueMixin[T], asyncio.PriorityQueue):
"""A PriorityQueue subclass that allows priorities to be updated (or computed) on the fly"""

_Key = Tuple[Tuple[Any], Tuple[Tuple[str, Any]]]

class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V]):
"""A PriorityProcessingQueue subclass that will execute jobs with the most waiters first"""
_no_futs = False
Expand All @@ -237,14 +246,23 @@ def __init__(
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
super().__init__(func, num_workers, return_data=True, loop=loop)
self._futs: Dict[_Key[T], SmartFuture[T]] = {}
async def put(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
self._workers
fut = asyncio.get_event_loop().create_future()
key = self._get_key(*args, **kwargs)
if fut := self._futs.get(key, None):
return fut
fut = self._create_future(key)
self._futs[key] = fut
await Queue.put(self, (fut, args, kwargs))
return fut
def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
self._workers
fut = self._create_future()
key = self._get_key(*args, **kwargs)
if fut := self._futs.get(key, None):
return fut
fut = self._create_future(key)
self._futs[key] = fut
Queue.put_nowait(self, (fut, args, kwargs))
return fut
def _get(self):
Expand Down

0 comments on commit b6ae00d

Please sign in to comment.