diff --git a/.gitignore b/.gitignore index cbd36b4c..79aeec4d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ env/ +.hypothesis/ .mypy_cache/ .pytest_cache/ __pycache__/ diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 172f8801..30ee7276 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -66,3 +66,7 @@ def __init__(self): err = f"The event loop is already running, which means you're trying to use an ASync function synchronously from within an async context.\n" err += f"Check your traceback to determine which, then try calling asynchronously instead with one of the following kwargs:\n" err += f"{_flags.VIABLE_FLAGS}" + +class MappingNotEmptyError(Exception): + def __init__(self): + super().__init__("TaskMapping already contains some data. In order to use `map`, you need a fresh one.") diff --git a/a_sync/task.py b/a_sync/task.py index 90f283f9..073645bd 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -2,8 +2,8 @@ import asyncio import logging -from a_sync._typing import K, P, T, V from a_sync._typing import * +from a_sync import exceptions from a_sync.utils.iterators import as_yielded from a_sync.utils.as_completed import as_completed @@ -37,13 +37,14 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]": super().__setitem__(item, task) return task async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: - assert not self, "must be empty" - async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] + if self: + raise exceptions.MappingNotEmptyError + async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] self[key] # ensure task is running async for key, value in self.yield_completed(pop=pop): - yield __yield(key, value, yields) + yield _yield(key, value, yields) async for key, value in as_completed(self, aiter=True): - yield __yield(key, value, yields) + yield _yield(key, value, yields) async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: for k, task in dict(self).items(): if task.done(): @@ -69,15 +70,19 @@ def __prune_persisted_tasks(): raise e __persisted_tasks.discard(task) -def __yield(key: Any, value: Any, yields: Literal['keys', 'both']): +@overload +def _yield(key: K, value: V, yields: Literal['keys']) -> K:... +@overload +def _yield(key: K, value: V, yields: Literal['both']) -> Tuple[K, V]:... +def _yield(key: K, value: V, yields: Literal['keys', 'both']) -> Union[K, Tuple[K, V]]: if yields == 'both': - yield key, value + return key, value elif yields == 'keys': - yield key + return key else: raise ValueError(f"`yields` must be 'keys' or 'both'. You passed {yields}") -async def __yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]: +async def _yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]: if isinstance(iterable, AsyncIterable): async for key in iterable: yield key diff --git a/tests/test_task.py b/tests/test_task.py index 61149985..71b183d0 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,7 +1,7 @@ import asyncio import pytest -from a_sync import create_task +from a_sync import TaskMapping, create_task, exceptions @pytest.mark.asyncio_cooperative async def test_create_task(): @@ -27,3 +27,58 @@ async def task(): await asyncio.sleep(0) # previously, it failed here create_task(coro=task(), skip_gc_until_done=True) + +@pytest.mark.asyncio_cooperative +async def test_task_mapping_init(): + tasks = TaskMapping(_coro_fn) + assert tasks._coro_fn is _coro_fn + assert tasks._coro_fn_kwargs == {} + assert tasks._name == "" + tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None) + assert tasks._coro_fn_kwargs == {'kwarg0': 1, 'kwarg1': None} + assert tasks._name == "test" + +@pytest.mark.asyncio_cooperative +async def test_task_mapping(): + tasks = TaskMapping(_coro_fn) + # does it return the correct type + assert isinstance(tasks[0], asyncio.Task) + # does it correctly return existing values + assert tasks[1] is tasks[1] + # does the task return the correct value + assert await tasks[0] == "1" + # can it do it again + assert await tasks[0] == "1" + +@pytest.mark.asyncio_cooperative +async def test_task_mapping_map_with_sync_iter(): + tasks = TaskMapping(_coro_fn) + async for k, v in tasks.map(range(5)): + assert isinstance(k, int) + assert isinstance(v, str) + with pytest.raises(exceptions.MappingNotEmptyError): + async for k in tasks.map(range(5), yields='keys'): + assert isinstance(k, int) + tasks = TaskMapping(_coro_fn) + async for k in tasks.map(range(5), yields='keys'): + assert isinstance(k, int) + +@pytest.mark.asyncio_cooperative +async def test_task_mapping_map_with_async_iter(): + async def async_iter(): + for i in range(5): + yield i + tasks = TaskMapping(_coro_fn) + async for k, v in tasks.map(async_iter()): + assert isinstance(k, int) + assert isinstance(v, str) + with pytest.raises(exceptions.MappingNotEmptyError): + async for k in tasks.map(async_iter(), yields='keys'): + assert isinstance(k, int) + tasks = TaskMapping(_coro_fn) + async for k in tasks.map(async_iter(), yields='keys'): + assert isinstance(k, int) + +async def _coro_fn(i: int) -> str: + i += 1 + return str(i) * i \ No newline at end of file