Skip to content

Commit

Permalink
fix: TaskMapping.map
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Feb 12, 2024
1 parent ae3ae84 commit c1dddb8
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
env/
.hypothesis/
.mypy_cache/
.pytest_cache/
__pycache__/
Expand Down
4 changes: 4 additions & 0 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,7 @@ def __init__(self):
err = f"The event loop is already running, which means you're trying to use an ASync function synchronously from within an async context.\n"
err += f"Check your traceback to determine which, then try calling asynchronously instead with one of the following kwargs:\n"
err += f"{_flags.VIABLE_FLAGS}"

class MappingNotEmptyError(Exception):
def __init__(self):
super().__init__("TaskMapping already contains some data. In order to use `map`, you need a fresh one.")
23 changes: 14 additions & 9 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import asyncio
import logging

from a_sync._typing import K, P, T, V
from a_sync._typing import *
from a_sync import exceptions
from a_sync.utils.iterators import as_yielded
from a_sync.utils.as_completed import as_completed

Expand Down Expand Up @@ -37,13 +37,14 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
super().__setitem__(item, task)
return task
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
assert not self, "must be empty"
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
if self:
raise exceptions.MappingNotEmptyError
async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for key, value in self.yield_completed(pop=pop):
yield __yield(key, value, yields)
yield _yield(key, value, yields)
async for key, value in as_completed(self, aiter=True):
yield __yield(key, value, yields)
yield _yield(key, value, yields)
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
if task.done():
Expand All @@ -69,15 +70,19 @@ def __prune_persisted_tasks():
raise e
__persisted_tasks.discard(task)

def __yield(key: Any, value: Any, yields: Literal['keys', 'both']):
@overload
def _yield(key: K, value: V, yields: Literal['keys']) -> K:...
@overload
def _yield(key: K, value: V, yields: Literal['both']) -> Tuple[K, V]:...
def _yield(key: K, value: V, yields: Literal['keys', 'both']) -> Union[K, Tuple[K, V]]:
if yields == 'both':
yield key, value
return key, value
elif yields == 'keys':
yield key
return key
else:
raise ValueError(f"`yields` must be 'keys' or 'both'. You passed {yields}")

async def __yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]:
async def _yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]:
if isinstance(iterable, AsyncIterable):
async for key in iterable:
yield key
Expand Down
57 changes: 56 additions & 1 deletion tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import pytest

from a_sync import create_task
from a_sync import TaskMapping, create_task, exceptions

@pytest.mark.asyncio_cooperative
async def test_create_task():
Expand All @@ -27,3 +27,58 @@ async def task():
await asyncio.sleep(0)
# previously, it failed here
create_task(coro=task(), skip_gc_until_done=True)

@pytest.mark.asyncio_cooperative
async def test_task_mapping_init():
tasks = TaskMapping(_coro_fn)
assert tasks._coro_fn is _coro_fn
assert tasks._coro_fn_kwargs == {}
assert tasks._name == ""
tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None)
assert tasks._coro_fn_kwargs == {'kwarg0': 1, 'kwarg1': None}
assert tasks._name == "test"

@pytest.mark.asyncio_cooperative
async def test_task_mapping():
tasks = TaskMapping(_coro_fn)
# does it return the correct type
assert isinstance(tasks[0], asyncio.Task)
# does it correctly return existing values
assert tasks[1] is tasks[1]
# does the task return the correct value
assert await tasks[0] == "1"
# can it do it again
assert await tasks[0] == "1"

@pytest.mark.asyncio_cooperative
async def test_task_mapping_map_with_sync_iter():
tasks = TaskMapping(_coro_fn)
async for k, v in tasks.map(range(5)):
assert isinstance(k, int)
assert isinstance(v, str)
with pytest.raises(exceptions.MappingNotEmptyError):
async for k in tasks.map(range(5), yields='keys'):
assert isinstance(k, int)
tasks = TaskMapping(_coro_fn)
async for k in tasks.map(range(5), yields='keys'):
assert isinstance(k, int)

@pytest.mark.asyncio_cooperative
async def test_task_mapping_map_with_async_iter():
async def async_iter():
for i in range(5):
yield i
tasks = TaskMapping(_coro_fn)
async for k, v in tasks.map(async_iter()):
assert isinstance(k, int)
assert isinstance(v, str)
with pytest.raises(exceptions.MappingNotEmptyError):
async for k in tasks.map(async_iter(), yields='keys'):
assert isinstance(k, int)
tasks = TaskMapping(_coro_fn)
async for k in tasks.map(async_iter(), yields='keys'):
assert isinstance(k, int)

async def _coro_fn(i: int) -> str:
i += 1
return str(i) * i

0 comments on commit c1dddb8

Please sign in to comment.