Skip to content

Commit

Permalink
feat: Queue.get with multiple results (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Feb 24, 2024
1 parent 6ab59e4 commit c02c9a6
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions a_sync/primitives/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,37 @@ class _Queue(asyncio.Queue[T]):...
class Queue(_Queue[T]):
"""The only difference between an a_sync.Queue and an asyncio.Queue is that `get_nowait` can retrn multiple responses."""
@overload
def get_nowait(self, i: int = 1, can_return_less: Literal[False] = False) -> T:
async def get(self, i: int = 1, can_return_less: bool = False) -> T:
...
@overload
def get_nowait(self, i: int = 1, can_return_less: Literal[True] = False) -> List[T]:
async def get(self, i: int, can_return_less: bool = False) -> List[T]:
...
async def get(self, i: int = 1, can_return_less: bool = False) -> Union[T, List[T]]:
_validate_args(i, can_return_less)
if i == 1:
return await super().get()
try:
items = self.get_nowait(i, can_return_less=True)
except asyncio.QueueEmpty:
items = [await super().get()]
if len(items) == i or can_return_less:
return items
while len(items) < i:
items.extend(await self.get(i - len(items)))
return items

@overload
def get_nowait(self, i: int = 1, can_return_less: bool = False) -> T:
...
@overload
def get_nowait(self, i: int, can_return_less: bool = False) -> List[T]:
...
def get_nowait(self, i: int = 1, can_return_less: bool = False) -> Union[T, List[T]]:
"""
Just like `asyncio.Queue.get_nowait`, but will return `i` items instead of 1.
Set `can_return_less` to True if you want to receive up to `i` items.
"""
if can_return_less and i == 1:
raise ValueError("you cant set i == 1 with can_return_less == True")
_validate_args(i, can_return_less)
values = []
if i == -1:
while True:
Expand All @@ -36,5 +55,17 @@ def get_nowait(self, i: int = 1, can_return_less: bool = False) -> Union[T, List
except asyncio.QueueEmpty:
if can_return_less:
break
for value in values:
self.put_nowait(value)
raise
return values[0] if i == 1 else values
return values[0] if i == 1 else values

def _validate_args(i: int, can_return_less: bool) -> None:
if not isinstance(i, int):
raise TypeError(f"`i` must be a non-zero integer. You passed {i}")
if not isinstance(can_return_less, bool):
raise TypeError(f"`can_return_less` must be boolean. You passed {can_return_less}")
if i == 0:
raise ValueError(f"`i` must be a non-zero integer. You passed {i}")
if can_return_less and i == 1:
raise ValueError("you can't set i == 1 with can_return_less == True")

0 comments on commit c02c9a6

Please sign in to comment.