Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: TaskMapping.map #115

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
env/
.hypothesis/
.mypy_cache/
.pytest_cache/
__pycache__/
Expand Down
4 changes: 4 additions & 0 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,7 @@ def __init__(self):
err = f"The event loop is already running, which means you're trying to use an ASync function synchronously from within an async context.\n"
err += f"Check your traceback to determine which, then try calling asynchronously instead with one of the following kwargs:\n"
err += f"{_flags.VIABLE_FLAGS}"

class MappingNotEmptyError(Exception):
def __init__(self):
super().__init__("TaskMapping already contains some data. In order to use `map`, you need a fresh one.")
23 changes: 14 additions & 9 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import asyncio
import logging

from a_sync._typing import K, P, T, V
from a_sync._typing import *
from a_sync import exceptions
from a_sync.utils.iterators import as_yielded
from a_sync.utils.as_completed import as_completed

Expand Down Expand Up @@ -37,13 +37,14 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
super().__setitem__(item, task)
return task
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
assert not self, "must be empty"
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
if self:
raise exceptions.MappingNotEmptyError
async for key in as_yielded(*[_yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for key, value in self.yield_completed(pop=pop):
yield __yield(key, value, yields)
yield _yield(key, value, yields)
async for key, value in as_completed(self, aiter=True):
yield __yield(key, value, yields)
yield _yield(key, value, yields)
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
if task.done():
Expand All @@ -69,15 +70,19 @@ def __prune_persisted_tasks():
raise e
__persisted_tasks.discard(task)

def __yield(key: Any, value: Any, yields: Literal['keys', 'both']):
@overload
def _yield(key: K, value: V, yields: Literal['keys']) -> K:...
@overload
def _yield(key: K, value: V, yields: Literal['both']) -> Tuple[K, V]:...
def _yield(key: K, value: V, yields: Literal['keys', 'both']) -> Union[K, Tuple[K, V]]:
if yields == 'both':
yield key, value
return key, value
elif yields == 'keys':
yield key
return key
else:
raise ValueError(f"`yields` must be 'keys' or 'both'. You passed {yields}")

async def __yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]:
async def _yield_keys(iterable: AnyIterable[K]) -> AsyncIterator[K]:
if isinstance(iterable, AsyncIterable):
async for key in iterable:
yield key
Expand Down
57 changes: 56 additions & 1 deletion tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import pytest

from a_sync import create_task
from a_sync import TaskMapping, create_task, exceptions

@pytest.mark.asyncio_cooperative
async def test_create_task():
Expand All @@ -27,3 +27,58 @@ async def task():
await asyncio.sleep(0)
# previously, it failed here
create_task(coro=task(), skip_gc_until_done=True)

@pytest.mark.asyncio_cooperative
async def test_task_mapping_init():
tasks = TaskMapping(_coro_fn)
assert tasks._coro_fn is _coro_fn
assert tasks._coro_fn_kwargs == {}
assert tasks._name == ""
tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None)
assert tasks._coro_fn_kwargs == {'kwarg0': 1, 'kwarg1': None}
assert tasks._name == "test"

@pytest.mark.asyncio_cooperative
async def test_task_mapping():
tasks = TaskMapping(_coro_fn)
# does it return the correct type
assert isinstance(tasks[0], asyncio.Task)
# does it correctly return existing values
assert tasks[1] is tasks[1]
# does the task return the correct value
assert await tasks[0] == "1"
# can it do it again
assert await tasks[0] == "1"

@pytest.mark.asyncio_cooperative
async def test_task_mapping_map_with_sync_iter():
tasks = TaskMapping(_coro_fn)
async for k, v in tasks.map(range(5)):
assert isinstance(k, int)
assert isinstance(v, str)
with pytest.raises(exceptions.MappingNotEmptyError):
async for k in tasks.map(range(5), yields='keys'):
assert isinstance(k, int)
tasks = TaskMapping(_coro_fn)
async for k in tasks.map(range(5), yields='keys'):
assert isinstance(k, int)

@pytest.mark.asyncio_cooperative
async def test_task_mapping_map_with_async_iter():
async def async_iter():
for i in range(5):
yield i
tasks = TaskMapping(_coro_fn)
async for k, v in tasks.map(async_iter()):
assert isinstance(k, int)
assert isinstance(v, str)
with pytest.raises(exceptions.MappingNotEmptyError):
async for k in tasks.map(async_iter(), yields='keys'):
assert isinstance(k, int)
tasks = TaskMapping(_coro_fn)
async for k in tasks.map(async_iter(), yields='keys'):
assert isinstance(k, int)

async def _coro_fn(i: int) -> str:
i += 1
return str(i) * i
Loading