From 6f4ffdc6519d01ae1ae8e116e281d7b4a9ff9d8a Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:15:02 -0500 Subject: [PATCH] chore: refactor iter module (#152) --- a_sync/iter.py | 61 +++++++++++++++++++++++++++++++--------------- a_sync/task.py | 47 +++++++++++++++++------------------ tests/test_iter.py | 11 ++++++--- tests/test_task.py | 6 ++--- 4 files changed, 75 insertions(+), 50 deletions(-) diff --git a/a_sync/iter.py b/a_sync/iter.py index 11f52830..8e36ed9e 100644 --- a/a_sync/iter.py +++ b/a_sync/iter.py @@ -1,37 +1,60 @@ import asyncio +import functools +import inspect +import logging from a_sync._typing import * +logger = logging.getLogger(__name__) + class ASyncIterable(AsyncIterable[T], Iterable[T]): - """An abstract iterable implementation that can be used in both a `for` loop and an `async for` loop.""" + """A hybrid Iterable/AsyncIterable implementation that can be used in both a `for` loop and an `async for` loop.""" def __iter__(self) -> Iterator[T]: - yield from ASyncIterator.wrap(self.__aiter__()) + yield from ASyncIterator(self.__aiter__()) @classmethod - def wrap(self, aiterable: AsyncIterable[T]) -> "ASyncWrappedIterable[T]": - return ASyncWrappedIterable(aiterable) + def wrap(cls, wrapped: AsyncIterable[T]) -> "ASyncIterable[T]": + # NOTE: for backward-compatability only. Will be removed soon. + logger.warning("ASyncIterable.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterable(wrapped)`") + return cls(wrapped) + def __init__(self, async_iterable: AsyncIterable[T]): + self.__wrapped__ = async_iterable + def __aiter__(self) -> AsyncIterator[T]: + return self.__wrapped__.__aiter__() + __slots__ = "__wrapped__", + +AsyncGenFunc = Callable[P, AsyncGenerator[T, None]] class ASyncIterator(AsyncIterator[T], Iterator[T]): - """An abstract iterator implementation that can be used in both a `for` loop and an `async for` loop.""" + """A hybrid Iterator/AsyncIterator implementation that can be used in both a `for` loop and an `async for` loop.""" def __next__(self) -> T: try: return asyncio.get_event_loop().run_until_complete(self.__anext__()) except StopAsyncIteration as e: raise StopIteration from e + @overload + def wrap(cls, aiterator: AsyncIterator[T]) -> "ASyncIterator[T]":... + @overload + def wrap(cls, async_gen_func: AsyncGenFunc[P, T]) -> "ASyncGeneratorFunction[P, T]":... @classmethod - def wrap(self, aiterator: AsyncIterator[T]) -> "ASyncWrappedIterator[T]": - return ASyncWrappedIterator(aiterator) - -class ASyncWrappedIterable(ASyncIterable[T]): - __slots__ = "__aiterable", - def __init__(self, async_iterable: AsyncIterable[T]): - self.__aiterable = async_iterable - def __aiter__(self) -> AsyncIterator[T]: - return self.__aiterable.__aiter__() - -class ASyncWrappedIterator(ASyncIterator[T]): - __slots__ = "__aiterator", + def wrap(cls, wrapped): + if isinstance(wrapped, AsyncIterator): + logger.warning("This use case for ASyncIterator.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterator(wrapped)`") + return cls(wrapped) + elif inspect.isasyncgenfunction(wrapped): + return ASyncGeneratorFunction(wrapped) + raise TypeError(f"`wrapped` must be an AsyncIterator or an async generator function. You passed {wrapped}") def __init__(self, async_iterator: AsyncIterator[T]): - self.__aiterator = async_iterator + self.__wrapped__ = async_iterator async def __anext__(self) -> T: - return await self.__aiterator.__anext__() + return await self.__wrapped__.__anext__() + +class ASyncGeneratorFunction(Generic[P, T]): + __slots__ = "__wrapped__", + def __init__(self, async_gen_func: AsyncGenFunc[P, T]) -> None: + self.__wrapped__ = async_gen_func + functools.update_wrapper(self, self.__wrapped__) + assert inspect.isasyncgenfunction(self) + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ASyncIterator[T]: + return ASyncIterator(self.__wrapped__(*args, **kwargs)) + \ No newline at end of file diff --git a/a_sync/task.py b/a_sync/task.py index d93007b8..9e6a772c 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -24,29 +24,32 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until MappingFn = Callable[Concatenate[K, P], Awaitable[V]] -class TaskMapping(ASyncIterable[Tuple[K, V]], DefaultDict[K, "asyncio.Task[V]"]): - __slots__ = "_coro_fn", "_coro_fn_kwargs", "_name", "_loader" - def __init__(self, coro_fn: MappingFn[K, P, V] = None, *iterables: AnyIterable[K], name: str = '', **coro_fn_kwargs: P.kwargs) -> None: - self._coro_fn = coro_fn - self._coro_fn_kwargs = coro_fn_kwargs +class TaskMapping(ASyncIterable[Tuple[K, V]], Mapping[K, "asyncio.Task[V]"]): + __slots__ = "_wrapped_func", "_wrapped_func_kwargs", "_name", "_tasks", "_init_loader" + def __init__(self, wrapped_func: MappingFn[K, P, V] = None, *iterables: AnyIterable[K], name: str = '', **wrapped_func_kwargs: P.kwargs) -> None: + self._wrapped_func = wrapped_func + self._wrapped_func_kwargs = wrapped_func_kwargs self._name = name - self._loader: Optional["asyncio.Task[None]"] + self._tasks: Dict[K, "asyncio.Task[V]"] = {} + self._init_loader: Optional["asyncio.Task[None]"] if iterables: - self._loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables))) + self._init_loader = create_task(exhaust_iterator(self._tasks_for_iterables(*iterables))) else: - self._loader = None + self._init_loader = None def __setitem__(self, item: Any, value: Any) -> None: raise NotImplementedError("You cannot manually set items in a TaskMapping") def __getitem__(self, item: K) -> "asyncio.Task[V]": try: - return super().__getitem__(item) + return self._tasks[item] except KeyError: task = create_task( - coro=self._coro_fn(item, **self._coro_fn_kwargs), + coro=self._wrapped_func(item, **self._wrapped_func_kwargs), name=f"{self._name}[{item}]" if self._name else f"{item}", ) - super().__setitem__(item, task) + self._tasks[item] = task return task + def __len__(self) -> int: + return len(self._tasks) def __await__(self) -> Generator[Any, None, Dict[K, V]]: """await all tasks and returns a mapping with the results for each key""" return self._await().__await__() @@ -54,15 +57,15 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]: """aiterate thru all key-task pairs, yielding the key-result pair as each task completes""" yielded = set() # if you inited the TaskMapping with some iterators, we will load those - if self._loader: - while not self._loader.done(): + if self._init_loader: + while not self._init_loader.done(): async for key, value in self.yield_completed(pop=False): if key not in yielded: yield _yield(key, value, "both") yielded.add(key) await asyncio.sleep(0) # loader is already done by this point, but we need to check for exceptions - await self._loader + await self._init_loader elif not self: # if you didn't init the TaskMapping with iterators and you didn't start any tasks manually, we should fail raise exceptions.MappingIsEmptyError @@ -71,10 +74,6 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]: async for key, value in as_completed(self, aiter=True): if key not in yielded: yield _yield(key, value, "both") - #def keys(self) -> KeysView[K]: - # if self._loader and not self._loader.done(): - # raise RuntimeError("the loader needs time to complete. bob will figure out a way to make this not impact sync users") - # return super().keys() async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: if self: raise exceptions.MappingNotEmptyError @@ -83,19 +82,19 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera async for _ in self._tasks_for_iterables(*iterables): async for key, value in self.yield_completed(pop=pop): yield _yield(key, value, yields) - async for key, value in as_completed(self, aiter=True): + async for key, value in as_completed(self._tasks, aiter=True): if pop: - self.pop(key) + self._tasks.pop(key) yield _yield(key, value, yields) async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: - for k, task in dict(self).items(): + for k, task in dict(self._tasks).items(): if task.done(): if pop: - task = self.pop(k) + task = self._tasks.pop(k) yield k, await task async def _await(self) -> Dict[K, V]: - if self._loader: - await self._loader + if self._init_loader: + await self._init_loader if not self: raise exceptions.MappingIsEmptyError return await gather(self) diff --git a/tests/test_iter.py b/tests/test_iter.py index a87f58b5..d09cd660 100644 --- a/tests/test_iter.py +++ b/tests/test_iter.py @@ -2,8 +2,7 @@ import pytest -from a_sync.iter import (ASyncIterable, ASyncIterator, ASyncWrappedIterable, - ASyncWrappedIterator) +from a_sync.iter import ASyncIterable, ASyncIterator async def async_gen(): @@ -12,16 +11,20 @@ async def async_gen(): yield 2 def test_iterable_wrap(): - assert isinstance(ASyncIterable.wrap(async_gen()), ASyncWrappedIterable) + assert isinstance(ASyncIterable(async_gen()), ASyncIterable) + assert isinstance(ASyncIterable.wrap(async_gen()), ASyncIterable) def test_iterator_wrap(): - assert isinstance(ASyncIterator.wrap(async_gen()), ASyncWrappedIterator) + assert isinstance(ASyncIterator(async_gen()), ASyncIterator) + assert isinstance(ASyncIterator.wrap(async_gen()), ASyncIterator) def test_iterable_sync(): + assert [i for i in ASyncIterable(async_gen())] == [0, 1, 2] assert [i for i in ASyncIterable.wrap(async_gen())] == [0, 1, 2] @pytest.mark.asyncio_cooperative async def test_iterable_async(): + assert [i async for i in ASyncIterable(async_gen())] == [0, 1, 2] assert [i async for i in ASyncIterable.wrap(async_gen())] == [0, 1, 2] def test_iterator_sync(): diff --git a/tests/test_task.py b/tests/test_task.py index 2ffd40e6..ebc7e5c6 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -31,11 +31,11 @@ async def task(): @pytest.mark.asyncio_cooperative async def test_task_mapping_init(): tasks = TaskMapping(_coro_fn) - assert tasks._coro_fn is _coro_fn - assert tasks._coro_fn_kwargs == {} + assert tasks._wrapped_func is _coro_fn + assert tasks._wrapped_func_kwargs == {} assert tasks._name == "" tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None) - assert tasks._coro_fn_kwargs == {'kwarg0': 1, 'kwarg1': None} + assert tasks._wrapped_func_kwargs == {'kwarg0': 1, 'kwarg1': None} assert tasks._name == "test" @pytest.mark.asyncio_cooperative