From b6ae00d41d235c772384481768fec5aabf760e4f Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Mon, 22 Apr 2024 19:31:52 +0000 Subject: [PATCH] feat: cache SmartProcessingQueue futures --- a_sync/primitives/queue.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/a_sync/primitives/queue.py b/a_sync/primitives/queue.py index d1afeaa1..b581619e 100644 --- a/a_sync/primitives/queue.py +++ b/a_sync/primitives/queue.py @@ -105,7 +105,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": self._workers if self._no_futs: return await super().put((args, kwargs)) - fut = asyncio.get_event_loop().create_future() + fut = self._create_future() await super().put((args, kwargs, fut)) return fut def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": @@ -165,8 +165,11 @@ def _validate_args(i: int, can_return_less: bool) -> None: class SmartFuture(asyncio.Future, Generic[T]): - # classvar holds default value for instances - _waiters: Set["asyncio.Task[T]"] = set() + def __init__(self, queue: "SmartProcessingQueue", key: "_Key", *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(loop=loop) + self._queue = queue + self._key = key + self._waiters: Set["asyncio.Task[T]"] = set() def __repr__(self): return f"<{type(self).__name__} waiters={self.num_waiters} {self._state}>" def __await__(self): @@ -179,6 +182,8 @@ def __await__(self): logger.info("%s waiters: %s", self, self._waiters) yield self # This tells Task to wait for completion. self._waiters.remove(current_task) + if self.num_waiters == 0: + self._queue._futs.pop(self._key) if not self.done(): raise RuntimeError("await wasn't used with future") return self.result() # May raise too. @@ -220,12 +225,16 @@ def _get(self, heapify=heapq.heapify, heappop=heapq.heappop): heapify(self._queue) # take the job with the most waiters return heappop(self._queue) - def _create_future(self) -> "asyncio.Future[V]": - return SmartFuture(loop=asyncio.get_event_loop()) + def _get_key(self, *args, **kwargs) -> "_Key": + return (args, tuple((kwarg, kwargs[kwarg]) for kwarg in sorted(kwargs))) + def _create_future(self, key: "_Key") -> "asyncio.Future[V]": + return SmartFuture(key, loop=asyncio.get_event_loop()) class VariablePriorityQueue(_VariablePriorityQueueMixin[T], asyncio.PriorityQueue): """A PriorityQueue subclass that allows priorities to be updated (or computed) on the fly""" +_Key = Tuple[Tuple[Any], Tuple[Tuple[str, Any]]] + class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V]): """A PriorityProcessingQueue subclass that will execute jobs with the most waiters first""" _no_futs = False @@ -237,14 +246,23 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__(func, num_workers, return_data=True, loop=loop) + self._futs: Dict[_Key[T], SmartFuture[T]] = {} async def put(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]: self._workers - fut = asyncio.get_event_loop().create_future() + key = self._get_key(*args, **kwargs) + if fut := self._futs.get(key, None): + return fut + fut = self._create_future(key) + self._futs[key] = fut await Queue.put(self, (fut, args, kwargs)) return fut def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]: self._workers - fut = self._create_future() + key = self._get_key(*args, **kwargs) + if fut := self._futs.get(key, None): + return fut + fut = self._create_future(key) + self._futs[key] = fut Queue.put_nowait(self, (fut, args, kwargs)) return fut def _get(self):