Skip to content

Commit

Permalink
feat: SmartProcessingQueue (#209)
Browse files: browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Apr 23, 2024
1 parent 2f8cd12 commit cdb0ab4
Showing 1 changed file with 160 additions and 9 deletions.
169 changes: 160 additions & 9 deletions a_sync/primitives/queue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import heapq
import logging
import sys

Expand All @@ -17,13 +18,14 @@ class _Queue(asyncio.Queue[T]):
class Queue(_Queue[T]):
# for type hint support, no functional difference
async def get(self) -> T:
    """Remove and return the next item, waiting until one is available.

    Thin typed wrapper: the work is done by ``_Queue.get``. The base class
    is named explicitly instead of being reached through ``super()``.
    """
    return await _Queue.get(self)
def get_nowait(self) -> T:
    """Remove and return the next item without waiting.

    Thin typed wrapper around ``_Queue.get_nowait``; any exception raised by
    the base implementation propagates unchanged.
    """
    return _Queue.get_nowait(self)
async def put(self, item: T) -> None:
    """Add ``item`` to the queue, waiting for a free slot if necessary.

    The base ``put`` is a coroutine function (``_Queue`` derives from
    ``asyncio.Queue``), so it has to be awaited here. Returning its
    coroutine object from this ``async def`` would hand the caller an
    un-run coroutine and the item would never be enqueued.
    """
    # NOTE(review): ``asyncio.Queue.put`` finishes by calling
    # ``self.put_nowait(item)``. A subclass that redefines ``put_nowait``
    # with a different signature will have that override invoked from here;
    # confirm such subclasses do not route their own ``put`` through this
    # method.
    await _Queue.put(self, item)
def put_nowait(self, item: T) -> None:
    """Add ``item`` to the queue without waiting.

    Thin typed wrapper around ``_Queue.put_nowait``; any exception raised by
    the base implementation propagates unchanged.
    """
    return _Queue.put_nowait(self, item)

async def get_all(self) -> List[T]:
"""returns 1 or more items"""
Expand Down Expand Up @@ -100,19 +102,24 @@ def __del__(self) -> None:
}
asyncio.get_event_loop().call_exception_handler(context)
async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
    """Enqueue one call described by ``args`` and ``kwargs``.

    Returns:
        The future that was enqueued together with the arguments, or the
        result of the underlying ``put`` (no future is created) when
        ``self._no_futs`` is set.

    Raises:
        Whatever ``_ensure_workers`` raises when the worker task has already
        finished.
    """
    # ``_ensure_workers`` reads ``self._workers`` itself, so no separate bare
    # access is needed before it.
    self._ensure_workers()
    if self._no_futs:
        return await super().put((args, kwargs))
    # Futures come from the ``_create_future`` hook only, so subclasses can
    # substitute their own future type.
    fut = self._create_future()
    # NOTE(review): ``asyncio.Queue.put`` ends with ``self.put_nowait(item)``.
    # If the parent ``put`` reaches it, this class's ``put_nowait`` override
    # would receive the packed tuple as one positional argument -- verify.
    await super().put((args, kwargs, fut))
    return fut
def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
    """Enqueue one call described by ``args`` and ``kwargs`` without waiting.

    Returns:
        The future that was enqueued together with the arguments, or the
        result of the underlying ``put_nowait`` (no future is created) when
        ``self._no_futs`` is set.

    Raises:
        Whatever ``_ensure_workers`` raises when the worker task has already
        finished, plus anything the underlying ``put_nowait`` raises.
    """
    # ``_ensure_workers`` reads ``self._workers`` itself, so no separate bare
    # access is needed before it.
    self._ensure_workers()
    if self._no_futs:
        return super().put_nowait((args, kwargs))
    # Futures come from the ``_create_future`` hook only, so subclasses can
    # substitute their own future type.
    fut = self._create_future()
    super().put_nowait((args, kwargs, fut))
    return fut
def _create_future(self) -> "asyncio.Future[V]":
    """Return a new future created by the event loop from ``asyncio.get_event_loop()``."""
    loop = asyncio.get_event_loop()
    return loop.create_future()
def _ensure_workers(self) -> None:
    """Raise the worker task's exception if that task has already finished.

    Does nothing while ``self._workers`` is still running.
    """
    workers = self._workers
    if not workers.done():
        return
    raise workers.exception()
@functools.cached_property
def _workers(self) -> "asyncio.Task[NoReturn]":
from a_sync.task import create_task
Expand All @@ -139,7 +146,15 @@ async def _worker_coro(self) -> NoReturn:
args, kwargs, fut = await self.get()
fut.set_result(await self.func(*args, **kwargs))
except Exception as e:
fut.set_result(e)
try:
fut.set_exception(e)
except UnboundLocalError as u:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
if str(e) != "local variable 'fut' referenced before assignment":
logger.exception(u)
raise u
logger.exception(e)
raise e
self.task_done()


Expand All @@ -150,3 +165,139 @@ def _validate_args(i: int, can_return_less: bool) -> None:
raise TypeError(f"`can_return_less` must be boolean. You passed {can_return_less}")
if i <= 1:
raise ValueError(f"`i` must be an integer greater than 1. You passed {i}")


# Positional arguments of a queued call: a tuple of any length.
# (``Tuple[Any]`` would describe a tuple with exactly one element.)
_Args = Tuple[Any, ...]
# Keyword arguments of a queued call as any number of ``(name, value)`` pairs.
_Kwargs = Tuple[Tuple[str, Any], ...]
# Identity of a queued call, in the shape ``_get_key`` builds it:
# ``(args, sorted keyword pairs)``.
_Key = Tuple[_Args, _Kwargs]

class SmartFuture(asyncio.Future, Generic[T]):
    """A future that keeps track of the tasks currently awaiting it.

    The number of waiters is used as the ordering key (see ``__lt__``), and
    the future removes itself from its queue's ``_futs`` cache once it is
    done and its last waiter has stopped waiting.
    """

    def __init__(self, queue: "SmartProcessingQueue", key: "_Key", *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Create a pending future owned by ``queue`` and cached under ``key``.

        Args:
            queue: The queue whose ``_futs`` mapping caches this future.
            key: The key this future is stored under in ``queue._futs``.
            loop: Event loop passed through to ``asyncio.Future``.
        """
        super().__init__(loop=loop)
        self._queue = queue
        self._key = key
        # Tasks that are currently suspended in ``__await__`` on this future.
        self._waiters: Set["asyncio.Task[T]"] = set()

    def __repr__(self):
        return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>"

    def __await__(self):
        """Wait for the result while registering the current task as a waiter.

        Returns:
            The future's result.

        Raises:
            RuntimeError: If the generator is resumed before the future is done.
            Exception: Whatever ``result()`` raises for a failed or cancelled
                future.
        """
        logger.debug("entering %s", self)
        if self.done():
            return self.result()  # May raise too.
        self._asyncio_future_blocking = True
        current_task = asyncio.current_task(self._loop)
        self._waiters.add(current_task)
        logger.debug("awaiting %s", self)
        try:
            yield self  # This tells Task to wait for completion.
        finally:
            # When the future ends with an exception or a cancellation, the
            # Task resumes this generator by *throwing* into it, so the
            # cleanup must live in ``finally`` or it would be skipped.
            self._waiters.discard(current_task)
            if not self._waiters and self.done():
                cache = self._queue._futs
                # Only evict our own entry; the key may already be gone or
                # may now belong to a newer future for the same call.
                if cache.get(self._key) is self:
                    del cache[self._key]
        if not self.done():
            raise RuntimeError("await wasn't used with future")
        return self.result()  # May raise too.

    def __lt__(self, other: "SmartFuture") -> bool:
        """heap considers lower values as higher priority so a future with more waiters will be 'less than' a future with less waiters."""
        if not isinstance(other, SmartFuture):
            return NotImplemented
        return self.num_waiters > other.num_waiters

    @property
    def num_waiters(self) -> int:
        """Number of tasks currently awaiting this future."""
        return len(self._waiters)


class _PriorityQueueMixin(Generic[T]):
    """Mixin providing heap-backed ``_init`` / ``_put`` / ``_get`` storage hooks."""

    def _init(self, maxsize):
        """Create the empty backing list; ``maxsize`` is accepted but not used here."""
        heap: List[T] = []
        self._queue = heap

    def _put(self, item, heappush=heapq.heappush):
        """Push ``item`` onto the heap."""
        heappush(self._queue, item)

    def _get(self, heappop=heapq.heappop):
        """Pop and return the smallest item of the heap."""
        smallest = heappop(self._queue)
        return smallest

class PriorityProcessingQueue(_PriorityQueueMixin[T], ProcessingQueue[T, V]):
    """A ProcessingQueue whose jobs are retrieved lowest ``priority`` first.

    Items are stored on the heap as ``(priority, args, kwargs, fut)`` and
    handed to the worker as ``(args, kwargs, fut)``.
    """
    # NOTE: WIP
    # NOTE(review): heap ordering compares whole item tuples, so two jobs
    # with equal ``priority`` fall through to comparing ``args`` and then the
    # ``kwargs`` dicts, which raises TypeError. A tie-breaker is still needed.

    async def put(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue a job with ``priority``, waiting while the queue is full.

        Returns:
            The future that will receive the job's outcome.

        Raises:
            Whatever ``_ensure_workers`` raises when the worker task has
            already finished.
        """
        self._ensure_workers()
        # ``asyncio.Queue.put`` ends with ``self.put_nowait(item)``, which
        # would dispatch back to the override below and wrap the item a
        # second time. Wait for room cooperatively instead; the loop only
        # runs for a bounded queue that is currently full.
        while self.full():
            await asyncio.sleep(0)
            # Dead workers never drain the queue; surface that instead of
            # waiting forever.
            self._ensure_workers()
        return self.put_nowait(priority, *args, **kwargs)

    def put_nowait(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue a job with ``priority`` without waiting.

        Returns:
            The future that will receive the job's outcome.

        Raises:
            Whatever ``_ensure_workers`` raises when the worker task has
            already finished, plus anything ``Queue.put_nowait`` raises.
        """
        self._ensure_workers()
        fut = self._create_future()
        # Call ``Queue.put_nowait`` directly: ``super().put_nowait`` is the
        # parent's argument-wrapping version and would store a tuple that
        # ``_get`` cannot unpack.
        Queue.put_nowait(self, (priority, args, kwargs, fut))
        return fut

    def _get(self, heappop=heapq.heappop):
        """Pop the smallest item and drop its priority before returning it."""
        priority, args, kwargs, fut = heappop(self._queue)
        return args, kwargs, fut

class _VariablePriorityQueueMixin(_PriorityQueueMixin[T]):
    """Heap mixin for items whose ordering may change while they are queued."""

    def _get(self, heapify=heapq.heapify, heappop=heapq.heappop):
        """Re-sort the heap so changed priorities are honoured, then pop the smallest value."""
        heap = self._queue
        # Priorities may have moved since the items were pushed.
        heapify(heap)
        return heappop(heap)

    def _get_key(self, *args, **kwargs) -> _Key:
        """Build the key for a call: ``(args, keyword pairs sorted by name)``."""
        ordered_kwargs = tuple(sorted(kwargs.items(), key=lambda pair: pair[0]))
        return (args, ordered_kwargs)

    def _create_future(self, key: _Key) -> "asyncio.Future[V]":
        """Return a ``SmartFuture`` for ``key`` bound to this queue and the current event loop."""
        loop = asyncio.get_event_loop()
        return SmartFuture(self, key, loop=loop)

class VariablePriorityQueue(_VariablePriorityQueueMixin[T], asyncio.PriorityQueue):
    """A PriorityQueue subclass that allows priorities to be updated (or computed) on the fly.

    The inherited ``_get`` re-heapifies the backing list before every pop, so
    the item returned is the smallest one under the items' ordering at
    retrieval time rather than at insertion time.
    """
    # NOTE: WIP

class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V]):
    """A ProcessingQueue that executes the jobs with the most waiters first.

    Each distinct call (same positional and keyword arguments) is represented
    by one ``SmartFuture`` cached in ``self._futs``. While that future is
    cached, further puts for the same call return it instead of queueing the
    job again. The heap is re-sorted on every retrieval, so ordering follows
    the current number of waiters of each future.
    """
    _no_futs = False

    def __init__(
        self,
        func: Callable[Concatenate[T, P], Awaitable[V]],
        num_workers: int,
        *,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Set up the queue for ``func`` with ``num_workers`` workers.

        Args:
            func: The coroutine function each queued job is passed to.
            num_workers: Forwarded to the parent initializer.
            loop: Forwarded to the parent initializer.
        """
        super().__init__(func, num_workers, return_data=True, loop=loop)
        # Cached futures, keyed by the result of ``_get_key``.
        self._futs: Dict[_Key, SmartFuture[V]] = {}

    async def put(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
        """Queue a call, or return the cached future for an identical call.

        Waits while the queue is full before enqueueing a new job.

        Returns:
            The ``SmartFuture`` representing the call.

        Raises:
            Whatever ``_ensure_workers`` raises when the worker task has
            already finished.
        """
        self._ensure_workers()
        cached = self._futs.get(self._get_key(*args, **kwargs))
        if cached is not None:
            logger.info("using cached fut")
            return cached
        # ``asyncio.Queue.put`` ends with ``self.put_nowait(item)``, which
        # would dispatch back to the override below and queue the packed
        # tuple as if it were the call's arguments. Wait for room
        # cooperatively instead; the loop only runs for a bounded queue that
        # is currently full.
        while self.full():
            await asyncio.sleep(0)
            # Dead workers never drain the queue; surface that instead of
            # waiting forever.
            self._ensure_workers()
        # ``put_nowait`` re-checks the cache, covering a future that another
        # task registered while this one was waiting.
        return self.put_nowait(*args, **kwargs)

    def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
        """Queue a call without waiting, or return the cached future for it.

        Returns:
            The ``SmartFuture`` representing the call.

        Raises:
            Whatever ``_ensure_workers`` raises when the worker task has
            already finished, plus anything ``Queue.put_nowait`` raises.
        """
        self._ensure_workers()
        key = self._get_key(*args, **kwargs)
        cached = self._futs.get(key)
        if cached is not None:
            logger.info("using cached fut")
            return cached
        fut = self._create_future(key)
        self._futs[key] = fut
        # The future goes first so heap ordering is decided by its waiters.
        Queue.put_nowait(self, (fut, args, kwargs))
        return fut

    def _get(self):
        """Pop the highest-priority job and reorder it to ``(args, kwargs, fut)``."""
        fut, args, kwargs = super()._get()
        return args, kwargs, fut

    async def _worker_coro(self) -> NoReturn:
        """Process queued jobs forever, reporting each outcome to its future.

        Raises:
            Exception: Re-raised (after logging) when retrieving a job fails,
                which ends this worker.
        """
        args: P.args
        kwargs: P.kwargs
        fut: SmartFuture[V]
        while True:
            try:
                args, kwargs, fut = await self.get()
            except Exception as e:
                # Nothing was dequeued, so there is no future to report to
                # (and ``fut`` may still hold the previous job's future).
                logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
                logger.exception(e)
                raise
            try:
                if fut.num_waiters > 1:
                    logger.info("processing %s", fut)
                else:
                    logger.debug("processing %s", fut)
                result = await self.func(*args, **kwargs)
            except Exception as e:
                logger.info("%s: %s", type(e).__name__, e)
                # A future that is already done (e.g. cancelled through one
                # of its waiters) rejects further results.
                if not fut.done():
                    fut.set_exception(e)
            else:
                if not fut.done():
                    fut.set_result(result)
            self.task_done()

0 comments on commit cdb0ab4

Please sign in to comment.