Skip to content

Commit

Permalink
feat: cythonize iter.py (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 20, 2024
1 parent 067aa6a commit 8a63e48
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 35 deletions.
4 changes: 3 additions & 1 deletion a_sync/a_sync/_helpers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ cdef object _asyncify(object func, object executor): # type: ignore [misc]
if asyncio.iscoroutinefunction(func) or isinstance(func, ASyncFunction):
raise exceptions.FunctionNotSync(func)

cdef object sumbit = executor.submit
cdef object sumbit

submit = executor.submit

@functools.wraps(func)
async def _asyncify_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
Expand Down
89 changes: 55 additions & 34 deletions a_sync/iter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ else:
ViewFn = AnyFn[[T], bool]


_FORMAT_PATTERNS = ("{cls}", "{obj}")
cdef tuple[str] _FORMAT_PATTERNS = ("{cls}", "{obj}")


cdef object get_event_loop = asyncio.get_event_loop


class _AwaitableAsyncIterableMixin(AsyncIterable[T]):
Expand Down Expand Up @@ -108,8 +111,10 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]):
def __init_subclass__(cls, **kwargs) -> None:
# Determine the type used for T in the subclass
type_argument = T # Default value
type_string = ":obj:`T` objects"
cdef object type_argument = T # Default value
cdef str type_string = ":obj:`T` objects"
cdef object base, args
for base in getattr(cls, "__orig_bases__", []):
if not hasattr(base, "__args__"):
continue
Expand All @@ -126,46 +131,46 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]):
elif hasattr(type_argument, "__module__") and hasattr(
type_argument, "__qualname__"
):
type_string = f":class:`~{type_argument.__module__}.{type_argument.__qualname__}`"
type_string = ":class:`~{}.{}`".format(type_argument.__module__, type_argument.__qualname__)
elif hasattr(type_argument, "__module__") and hasattr(
type_argument, "__name__"
):
type_string = (
f":class:`~{type_argument.__module__}.{type_argument.__name__}`"
":class:`~{}.{}`".format(type_argument.__module__, type_argument.__name__)
)
elif hasattr(type_argument, "__qualname__"):
type_string = f":class:`{type_argument.__qualname__}`"
type_string = ":class:`{}`".format(type_argument.__qualname__)
elif hasattr(type_argument, "__name__"):
type_string = f":class:`{type_argument.__name__}`"
type_string = ":class:`{}`".format(type_argument.__name__)
else:
type_string = str(type_argument)
# modify the class docstring
new = (
f"When awaited, a list of all {type_string} will be returned.\n"
cdef str new_chunk = (
"When awaited, a list of all {} will be returned.\n".format(type_string) +
"\n"
"Example:\n"
f" >>> my_object = {cls.__name__}(...)\n"
" >>> my_object = {}(...)\n".format(cls.__name__) +
" >>> all_contents = await my_object\n"
" >>> isinstance(all_contents, list)\n"
" True\n"
f" >>> isinstance(all_contents[0], {type_argument.__name__})\n"
" >>> isinstance(all_contents[0], {})\n".format(type_argument.__name__) +
" True\n"
)
if cls.__doc__ is None:
cls.__doc__ = new
cls.__doc__ = new_chunk
elif not cls.__doc__ or cls.__doc__.endswith("\n\n"):
cls.__doc__ += new
cls.__doc__ += new_chunk
elif cls.__doc__.endswith("\n"):
cls.__doc__ += f"\n{new}"
cls.__doc__ += "\n{}".format(new_chunk)
else:
cls.__doc__ += f"\n\n{new}"
cls.__doc__ += "\n\n{}".format(new_chunk)
# Update method docstrings by redefining methods
# This is necessary because, by default, subclasses inherit methods from their bases
# which means if we just update the docstring we might edit docs for unrelated objects
functions_to_redefine = {
cdef dict functions_to_redefine = {
attr_name: attr_value
for attr_name in dir(cls)
if (attr_value := getattr(cls, attr_name, None))
Expand All @@ -174,6 +179,8 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]):
and any(pattern in attr_value.__doc__ for pattern in _FORMAT_PATTERNS)
}
cdef str function_name
cdef object function_obj
for function_name, function_obj in functions_to_redefine.items():
# Create a new function object with the docstring formatted appropriately for this class
redefined_function_obj = FunctionType(
Expand Down Expand Up @@ -246,16 +253,16 @@ class ASyncIterable(_AwaitableAsyncIterableMixin[T], Iterable[T]):
"""
if not isinstance(async_iterable, AsyncIterable):
raise TypeError(
f"`async_iterable` must be an AsyncIterable. You passed {async_iterable}"
"`async_iterable` must be an AsyncIterable. You passed {}".format(async_iterable)
)
self.__wrapped__ = async_iterable
"The wrapped async iterable object."
def __repr__(self) -> str:
start = f"<{type(self).__name__}"
start = "<{}".format(type(self).__name__)
if wrapped := getattr(self, "__wrapped__", None):
start += f" for {self.__wrapped__}"
return f"{start} at {hex(id(self))}>"
start += " for {}".format(self.__wrapped__)
return "{} at {}>".format(start, hex(id(self)))
def __aiter__(self) -> AsyncIterator[T]:
"""
Expand Down Expand Up @@ -322,7 +329,7 @@ class ASyncIterator(_AwaitableAsyncIterableMixin[T], Iterator[T]):

"""
try:
return asyncio.get_event_loop().run_until_complete(self.__anext__())
return get_event_loop().run_until_complete(self.__anext__())
except StopAsyncIteration as e:
raise StopIteration from e
except RuntimeError as e:
Expand Down Expand Up @@ -361,7 +368,7 @@ class ASyncIterator(_AwaitableAsyncIterableMixin[T], Iterator[T]):
elif inspect.isasyncgenfunction(wrapped):
return ASyncGeneratorFunction(wrapped)
raise TypeError(
f"`wrapped` must be an AsyncIterator or an async generator function. You passed {wrapped}"
"`wrapped` must be an AsyncIterator or an async generator function. You passed {}".format(wrapped)
)
def __init__(self, async_iterator: AsyncIterator[T]):
Expand All @@ -373,7 +380,7 @@ class ASyncIterator(_AwaitableAsyncIterableMixin[T], Iterator[T]):
"""
if not isinstance(async_iterator, AsyncIterator):
raise TypeError(
f"`async_iterator` must be an AsyncIterator. You passed {async_iterator}"
"`async_iterator` must be an AsyncIterator. You passed {}".format(async_iterator)
)
self.__wrapped__ = async_iterator
"The wrapped :class:`AsyncIterator`."
Expand Down Expand Up @@ -456,7 +463,11 @@ class ASyncGeneratorFunction(Generic[P, T]):
functools.update_wrapper(self, self.__wrapped__)
def __repr__(self) -> str:
return f"<{type(self).__name__} for {self.__wrapped__} at {hex(id(self))}>"
return "<{} for {} at {}>".format(
type(self).__name__,
self.__wrapped__,
hex(id(self))
)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ASyncIterator[T]:
"""
Expand All @@ -474,6 +485,8 @@ class ASyncGeneratorFunction(Generic[P, T]):
"Descriptor method to make the function act like a non-data descriptor."
if instance is None:
return self
cdef object gen_func
try:
gen_func = instance.__dict__[self.field_name]
except KeyError:
Expand All @@ -485,17 +498,18 @@ class ASyncGeneratorFunction(Generic[P, T]):
@property
def __self__(self) -> object:
cdef object instance
try:
instance = self.__weakself__()
except TypeError:
raise AttributeError(f"{self} has no attribute '__self__'") from None
raise AttributeError("{} has no attribute '__self__'".format(self)) from None
if instance is None:
raise ReferenceError(self)
return instance
def __get_cache_handle(self, instance: object) -> asyncio.TimerHandle:
# NOTE: we create a strong reference to instance here. I'm not sure if this is good or not but its necessary for now.
return asyncio.get_event_loop().call_later(
return get_event_loop().call_later(
300, delattr, instance, self.field_name
)
Expand Down Expand Up @@ -534,7 +548,7 @@ class _ASyncView(ASyncIterator[T]):
self.__iterator__ = iterable.__iter__()
else:
raise TypeError(
f"`iterable` must be AsyncIterable or Iterable, you passed {iterable}"
"`iterable` must be AsyncIterable or Iterable, you passed {}".format(iterable)
)
Expand All @@ -561,9 +575,12 @@ class ASyncFilter(_ASyncView[T]):
"""
def __repr__(self) -> str:
return f"<ASyncFilter for iterator={self.__wrapped__} function={self._function.__name__} at {hex(id(self))}>"
return "<ASyncFilter for iterator={} function={} at {}>".format(
self.__wrapped__, self._function.__name__, hex(id(self))
)
async def __anext__(self) -> T:
cdef object obj
if self.__aiterator__:
async for obj in self.__aiterator__:
if await self._check(obj):
Expand All @@ -589,11 +606,11 @@ class ASyncFilter(_ASyncView[T]):
Returns:
True if the object passes the filter, False otherwise.
"""
checked = self._function(obj)
cdef object checked = self._function(obj)
return bool(await checked) if inspect.isawaitable(checked) else bool(checked)
def _key_if_no_key(obj: T) -> T:
cdef object _key_if_no_key(object obj):
"""
Default key function that returns the object itself if no key is provided.

Expand Down Expand Up @@ -657,17 +674,18 @@ class ASyncSorter(_ASyncView[T]):
An async iterator that will yield the sorted {obj}.
"""
if self._consumed:
raise RuntimeError(f"{self} has already been consumed")
raise RuntimeError("{} has already been consumed".format(self))
return self
def __repr__(self) -> str:
cdef str rep
rep = "<ASyncSorter"
if self.reversed:
rep += " reversed"
rep += f" for iterator={self.__wrapped__}"
rep += " for iterator={}".format(self.__wrapped__)
if self._function is not _key_if_no_key:
rep += f" key={self._function.__name__}"
rep += f" at {hex(id(self))}>"
rep += " key={}".format(self._function.__name__)
rep += " at {}>".format(hex(id(self)))
return rep
def __anext__(self) -> T:
Expand All @@ -683,6 +701,9 @@ class ASyncSorter(_ASyncView[T]):
Returns:
An async iterator that will yield the sorted items.
"""
cdef list items, sort_tasks
cdef object obj
if asyncio.iscoroutinefunction(self._function):
items = []
sort_tasks = []
Expand Down

0 comments on commit 8a63e48

Please sign in to comment.