Skip to content

Commit

Permalink
chore: refactor iter module (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Mar 1, 2024
1 parent 330e762 commit 6f4ffdc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 50 deletions.
61 changes: 42 additions & 19 deletions a_sync/iter.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,60 @@

import asyncio
import functools
import inspect
import logging
from a_sync._typing import *


logger = logging.getLogger(__name__)

class ASyncIterable(AsyncIterable[T], Iterable[T]):
"""An abstract iterable implementation that can be used in both a `for` loop and an `async for` loop."""
"""A hybrid Iterable/AsyncIterable implementation that can be used in both a `for` loop and an `async for` loop."""
def __iter__(self) -> Iterator[T]:
yield from ASyncIterator.wrap(self.__aiter__())
yield from ASyncIterator(self.__aiter__())
@classmethod
def wrap(self, aiterable: AsyncIterable[T]) -> "ASyncWrappedIterable[T]":
return ASyncWrappedIterable(aiterable)
def wrap(cls, wrapped: AsyncIterable[T]) -> "ASyncIterable[T]":
# NOTE: for backward-compatability only. Will be removed soon.
logger.warning("ASyncIterable.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterable(wrapped)`")
return cls(wrapped)
def __init__(self, async_iterable: AsyncIterable[T]):
self.__wrapped__ = async_iterable
def __aiter__(self) -> AsyncIterator[T]:
return self.__wrapped__.__aiter__()
__slots__ = "__wrapped__",

AsyncGenFunc = Callable[P, AsyncGenerator[T, None]]

class ASyncIterator(AsyncIterator[T], Iterator[T]):
"""An abstract iterator implementation that can be used in both a `for` loop and an `async for` loop."""
"""A hybrid Iterator/AsyncIterator implementation that can be used in both a `for` loop and an `async for` loop."""
def __next__(self) -> T:
try:
return asyncio.get_event_loop().run_until_complete(self.__anext__())
except StopAsyncIteration as e:
raise StopIteration from e
@overload
def wrap(cls, aiterator: AsyncIterator[T]) -> "ASyncIterator[T]":...
@overload
def wrap(cls, async_gen_func: AsyncGenFunc[P, T]) -> "ASyncGeneratorFunction[P, T]":...
@classmethod
def wrap(self, aiterator: AsyncIterator[T]) -> "ASyncWrappedIterator[T]":
return ASyncWrappedIterator(aiterator)

class ASyncWrappedIterable(ASyncIterable[T]):
__slots__ = "__aiterable",
def __init__(self, async_iterable: AsyncIterable[T]):
self.__aiterable = async_iterable
def __aiter__(self) -> AsyncIterator[T]:
return self.__aiterable.__aiter__()

class ASyncWrappedIterator(ASyncIterator[T]):
__slots__ = "__aiterator",
def wrap(cls, wrapped):
if isinstance(wrapped, AsyncIterator):
logger.warning("This use case for ASyncIterator.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterator(wrapped)`")
return cls(wrapped)
elif inspect.isasyncgenfunction(wrapped):
return ASyncGeneratorFunction(wrapped)
raise TypeError(f"`wrapped` must be an AsyncIterator or an async generator function. You passed {wrapped}")
def __init__(self, async_iterator: AsyncIterator[T]):
self.__aiterator = async_iterator
self.__wrapped__ = async_iterator
async def __anext__(self) -> T:
return await self.__aiterator.__anext__()
return await self.__wrapped__.__anext__()

class ASyncGeneratorFunction(Generic[P, T]):
__slots__ = "__wrapped__",
def __init__(self, async_gen_func: AsyncGenFunc[P, T]) -> None:
self.__wrapped__ = async_gen_func
functools.update_wrapper(self, self.__wrapped__)
assert inspect.isasyncgenfunction(self)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ASyncIterator[T]:
return ASyncIterator(self.__wrapped__(*args, **kwargs))

47 changes: 23 additions & 24 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,48 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until

MappingFn = Callable[Concatenate[K, P], Awaitable[V]]

class TaskMapping(ASyncIterable[Tuple[K, V]], DefaultDict[K, "asyncio.Task[V]"]):
__slots__ = "_coro_fn", "_coro_fn_kwargs", "_name", "_loader"
def __init__(self, coro_fn: MappingFn[K, P, V] = None, *iterables: AnyIterable[K], name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
class TaskMapping(ASyncIterable[Tuple[K, V]], Mapping[K, "asyncio.Task[V]"]):
__slots__ = "_wrapped_func", "_wrapped_func_kwargs", "_name", "_tasks", "_init_loader"
def __init__(self, wrapped_func: MappingFn[K, P, V] = None, *iterables: AnyIterable[K], name: str = '', **wrapped_func_kwargs: P.kwargs) -> None:
self._wrapped_func = wrapped_func
self._wrapped_func_kwargs = wrapped_func_kwargs
self._name = name
self._loader: Optional["asyncio.Task[None]"]
self._tasks: Dict[K, "asyncio.Task[V]"] = {}
self._init_loader: Optional["asyncio.Task[None]"]
if iterables:
self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables)))
self._init_loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables)))
else:
self._loader = None
self._init_loader = None
def __setitem__(self, item: Any, value: Any) -> None:
raise NotImplementedError("You cannot manually set items in a TaskMapping")
def __getitem__(self, item: K) -> "asyncio.Task[V]":
try:
return super().__getitem__(item)
return self._tasks[item]
except KeyError:
task = create_task(
coro=self._coro_fn(item, **self._coro_fn_kwargs),
coro=self._wrapped_func(item, **self._wrapped_func_kwargs),
name=f"{self._name}[{item}]" if self._name else f"{item}",
)
super().__setitem__(item, task)
self._tasks[item] = task
return task
def __len__(self) -> int:
return len(self._tasks)
def __await__(self) -> Generator[Any, None, Dict[K, V]]:
"""await all tasks and returns a mapping with the results for each key"""
return self._await().__await__()
async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]:
"""aiterate thru all key-task pairs, yielding the key-result pair as each task completes"""
yielded = set()
# if you inited the TaskMapping with some iterators, we will load those
if self._loader:
while not self._loader.done():
if self._init_loader:
while not self._init_loader.done():
async for key, value in self.yield_completed(pop=False):
if key not in yielded:
yield _yield(key, value, "both")
yielded.add(key)
await asyncio.sleep(0)
# loader is already done by this point, but we need to check for exceptions
await self._loader
await self._init_loader
elif not self:
# if you didn't init the TaskMapping with iterators and you didn't start any tasks manually, we should fail
raise exceptions.MappingIsEmptyError
Expand All @@ -71,10 +74,6 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]:
async for key, value in as_completed(self, aiter=True):
if key not in yielded:
yield _yield(key, value, "both")
#def keys(self) -> KeysView[K]:
# if self._loader and not self._loader.done():
# raise RuntimeError("the loader needs time to complete. bob will figure out a way to make this not impact sync users")
# return super().keys()
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
if self:
raise exceptions.MappingNotEmptyError
Expand All @@ -83,19 +82,19 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera
async for _ in self._tasks_for_iterables(*iterables):
async for key, value in self.yield_completed(pop=pop):
yield _yield(key, value, yields)
async for key, value in as_completed(self, aiter=True):
async for key, value in as_completed(self._tasks, aiter=True):
if pop:
self.pop(key)
self._tasks.pop(key)
yield _yield(key, value, yields)
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
for k, task in dict(self._tasks).items():
if task.done():
if pop:
task = self.pop(k)
task = self._tasks.pop(k)
yield k, await task
async def _await(self) -> Dict[K, V]:
if self._loader:
await self._loader
if self._init_loader:
await self._init_loader
if not self:
raise exceptions.MappingIsEmptyError
return await gather(self)
Expand Down
11 changes: 7 additions & 4 deletions tests/test_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import pytest

from a_sync.iter import (ASyncIterable, ASyncIterator, ASyncWrappedIterable,
ASyncWrappedIterator)
from a_sync.iter import ASyncIterable, ASyncIterator


async def async_gen():
Expand All @@ -12,16 +11,20 @@ async def async_gen():
yield 2

def test_iterable_wrap():
assert isinstance(ASyncIterable.wrap(async_gen()), ASyncWrappedIterable)
assert isinstance(ASyncIterable(async_gen()), ASyncIterable)
assert isinstance(ASyncIterable.wrap(async_gen()), ASyncIterable)

def test_iterator_wrap():
assert isinstance(ASyncIterator.wrap(async_gen()), ASyncWrappedIterator)
assert isinstance(ASyncIterator(async_gen()), ASyncIterator)
assert isinstance(ASyncIterator.wrap(async_gen()), ASyncIterator)

def test_iterable_sync():
assert [i for i in ASyncIterable(async_gen())] == [0, 1, 2]
assert [i for i in ASyncIterable.wrap(async_gen())] == [0, 1, 2]

@pytest.mark.asyncio_cooperative
async def test_iterable_async():
assert [i async for i in ASyncIterable(async_gen())] == [0, 1, 2]
assert [i async for i in ASyncIterable.wrap(async_gen())] == [0, 1, 2]

def test_iterator_sync():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ async def task():
@pytest.mark.asyncio_cooperative
async def test_task_mapping_init():
tasks = TaskMapping(_coro_fn)
assert tasks._coro_fn is _coro_fn
assert tasks._coro_fn_kwargs == {}
assert tasks._wrapped_func is _coro_fn
assert tasks._wrapped_func_kwargs == {}
assert tasks._name == ""
tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None)
assert tasks._coro_fn_kwargs == {'kwarg0': 1, 'kwarg1': None}
assert tasks._wrapped_func_kwargs == {'kwarg0': 1, 'kwarg1': None}
assert tasks._name == "test"

@pytest.mark.asyncio_cooperative
Expand Down

0 comments on commit 6f4ffdc

Please sign in to comment.