Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smarter #275

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions a_sync/a_sync/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import a_sync.asyncio
from a_sync import exceptions
from a_sync._typing import *
from a_sync.a_sync.function import ASyncFunction


def _await(awaitable: Awaitable[T]) -> T:
Expand All @@ -15,9 +14,10 @@ def _await(awaitable: Awaitable[T]) -> T:
except RuntimeError as e:
if str(e) == "This event loop is already running":
raise exceptions.SyncModeInAsyncContextError from None
raise e
raise

def _asyncify(func: SyncFn[P, T], executor: Executor) -> CoroFn[P, T]: # type: ignore [misc]
from a_sync.a_sync.function import ASyncFunction
if asyncio.iscoroutinefunction(func) or isinstance(func, ASyncFunction):
raise exceptions.FunctionNotSync(func)
@functools.wraps(func)
Expand Down
58 changes: 40 additions & 18 deletions a_sync/a_sync/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,47 @@ def fn(self): # -> Union[SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]], SyncFn[[Sync
"""Returns the final wrapped version of 'self._fn' decorated with all of the a_sync goodness."""
return self._async_wrap if self._async_def else self._sync_wrap

def map(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> "TaskMapping[P, T]":
from a_sync import TaskMapping
return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **kwargs)

async def any(self, *iterables: AnyIterable[object], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).any(pop=True, sync=False)

async def all(self, *iterables: AnyIterable[object], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).all(pop=True, sync=False)

async def min(self, *iterables: AnyIterable[object], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).min(pop=True, sync=False)

async def max(self, *iterables: AnyIterable[object], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).max(pop=True, sync=False)

async def sum(self, *iterables: AnyIterable[object], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).sum(pop=True, sync=False)
if sys.version_info >= (3, 11) or TYPE_CHECKING:
# we can specify P.args in python>=3.11 but in lower versions it causes a crash. Everything should still type check correctly on all versions.
def map(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> "TaskMapping[P, T]":
from a_sync import TaskMapping
return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **function_kwargs)

async def any(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).any(pop=True, sync=False)

async def all(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).all(pop=True, sync=False)

async def min(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).min(pop=True, sync=False)

async def max(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).max(pop=True, sync=False)

async def sum(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).sum(pop=True, sync=False)

else:
def map(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> "TaskMapping[P, T]":
from a_sync import TaskMapping
return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **function_kwargs)

async def any(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).any(pop=True, sync=False)

async def all(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).all(pop=True, sync=False)

async def min(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).min(pop=True, sync=False)

async def max(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).max(pop=True, sync=False)

async def sum(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T:
return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).sum(pop=True, sync=False)

@functools.cached_property
def _sync_default(self) -> bool:
"""If user did not specify a default, we defer to the function. 'def' vs 'async def'"""
Expand Down
2 changes: 1 addition & 1 deletion a_sync/asyncio/as_completed.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: bool = False)
except Exception as e:
if return_exceptions:
return k, e
raise e
raise

__all__ = ["as_completed", "as_completed_mapping"]
2 changes: 1 addition & 1 deletion a_sync/asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def get_event_loop() -> asyncio.AbstractEventLoop:
loop = asyncio.get_event_loop()
except RuntimeError as e: # Necessary for use with multi-threaded applications.
if not str(e).startswith("There is no current event loop in thread"):
raise e
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
10 changes: 6 additions & 4 deletions a_sync/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ def __cancel_cache_handle(self, instance: object) -> None:
self._cache_handle.cancel()

class _ASyncView(ASyncIterator[T]):
__aiterator__ = None
__iterator__ = None
__aiterator__: Optional[AsyncIterator[T]] = None
__iterator__: Optional[Iterator[T]] = None
def __init__(
self,
function: ViewFn[T],
iterable: AsyncIterable[T],
iterable: AnyIterable[T],
) -> None:
self._function = function
self.__wrapped__ = iterable
Expand All @@ -184,7 +184,7 @@ def __init__(
elif isinstance(iterable, Iterable):
self.__iterator__ = iterable.__iter__()
else:
raise TypeError(f"`iterable` must be AsyncIterable or Iterabe, you passed {iterable}")
raise TypeError(f"`iterable` must be AsyncIterable or Iterable, you passed {iterable}")

@final
class ASyncFilter(_ASyncView[T]):
Expand All @@ -202,6 +202,8 @@ async def __anext__(self) -> T:
return obj
except StopIteration:
pass
else:
raise TypeError(self.__wrapped__)
raise StopAsyncIteration from None
async def _check(self, obj: T) -> bool:
checked = self._function(obj)
Expand Down
44 changes: 19 additions & 25 deletions a_sync/primitives/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __repr__(self) -> str:
if self._unfinished_tasks:
repr_string += f" pending={self._unfinished_tasks}"
return f"{repr_string}>"
# NOTE: asyncio defines both this and __repr__
def __str__(self) -> str:
repr_string = f"<{type(self).__name__}"
if self._name:
Expand Down Expand Up @@ -196,27 +197,25 @@ async def __worker_coro(self) -> NoReturn:
while True:
try:
args, kwargs, fut = await self.get()
if fut is None:
# the weakref was already cleaned up, we don't need to process this item
self.task_done()
continue
result = await self.func(*args, **kwargs)
fut.set_result(result)
except asyncio.exceptions.InvalidStateError:
logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result)
except Exception as e:
try:
fut.set_exception(e)
if fut is None:
# the weakref was already cleaned up, we don't need to process this item
self.task_done()
continue
result = await self.func(*args, **kwargs)
fut.set_result(result)
except asyncio.exceptions.InvalidStateError:
logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e)
except UnboundLocalError as u:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
if str(e) != "local variable 'fut' referenced before assignment":
logger.exception(u)
raise u
logger.exception(e)
raise e
self.task_done()
logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result)
except Exception as e:
try:
fut.set_exception(e)
except asyncio.exceptions.InvalidStateError:
logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e)
self.task_done()
except Exception as e:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
logger.exception(e)
raise


def _validate_args(i: int, can_return_less: bool) -> None:
Expand Down Expand Up @@ -337,13 +336,8 @@ async def __worker_coro(self) -> NoReturn:
fut.set_exception(e)
except asyncio.exceptions.InvalidStateError:
logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e)
except UnboundLocalError as u:
if str(e) != "local variable 'fut' referenced before assignment":
logger.exception(u)
raise u
raise e
self.task_done()
except Exception as e:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
logger.exception(e)
raise e
raise
10 changes: 5 additions & 5 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,12 @@ def __init__(
@functools.wraps(wrapped_func)
async def _wrapped_set_next(*args: P.args, __a_sync_recursion: int = 0, **kwargs: P.kwargs) -> V:
try:
retval = await wrapped_func(*args, **kwargs)
self._next.set()
self._next.clear()
return retval
return await wrapped_func(*args, **kwargs)
except exceptions.SyncModeInAsyncContextError as e:
raise Exception(e, self.__wrapped__)
except TypeError as e:
if __a_sync_recursion > 2 or not (str(e).startswith(wrapped_func.__name__) and "got multiple values for argument" in str(e)):
raise e
raise
# NOTE: args ordering is clashing with provided kwargs. We can handle this in a hacky way.
# TODO: perform this check earlier and pre-prepare the args/kwargs ordering
new_args = list(args)
Expand All @@ -127,6 +124,9 @@ async def _wrapped_set_next(*args: P.args, __a_sync_recursion: int = 0, **kwargs
return await _wrapped_set_next(*new_args, **new_kwargs, __a_sync_recursion=__a_sync_recursion+1)
except TypeError as e2:
raise e.with_traceback(e.__traceback__) if str(e2) == "unsupported callable" else e2.with_traceback(e2.__traceback__)
finally:
self._next.set()
self._next.clear()
self._wrapped_func = _wrapped_set_next
init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue()
self.__init_loader_coro = exhaust_iterator(self._tasks_for_iterables(*iterables), queue=init_loader_queue)
Expand Down
4 changes: 2 additions & 2 deletions a_sync/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def any(*awaitables) -> bool:
if str(e) == "cannot reuse already awaited coroutine":
raise RuntimeError(str(e), fut) from e
else:
raise e
raise
if bool(result):
for fut in futs:
fut.cancel()
Expand All @@ -37,7 +37,7 @@ async def all(*awaitables) -> bool:
if str(e) == "cannot reuse already awaited coroutine":
raise RuntimeError(str(e), fut) from e
else:
raise e
raise
if not result:
for fut in futs:
fut.cancel()
Expand Down
3 changes: 2 additions & 1 deletion a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None,
"""
for x in await asyncio.gather(*[exhaust_iterator(iterator, queue=queue) for iterator in iterators], return_exceptions=True):
if isinstance(x, Exception):
raise x
# raise it with its original traceback instead of from here
raise x.with_traceback(x.__traceback__)
if queue:
queue.put_nowait(_Done())
if join:
Expand Down
Loading