Skip to content

Commit

Permalink
Emit correct type stubs for async functions wrapped with additional d…
Browse files Browse the repository at this point in the history
…ecorators (#194)

* Emit correct type stubs for async functions wrapped with additional decorators

* Address comments

* Simplify

* Lint
  • Loading branch information
mwaskom authored Dec 17, 2024
1 parent 9986bf6 commit 42f549b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
18 changes: 18 additions & 0 deletions src/synchronicity/async_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ async def wrapper(*args, **kwargs):
return functools.wraps(func)


def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool:
"""Determine if func returns a coroutine, unwrapping decorators, but not the async synchronicity interace."""
from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import

if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING:
return is_coroutine_function_follow_wrapped(func.__wrapped__)
return inspect.iscoroutinefunction(func)


def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool:
"""Determine if func returns an async generator, unwrapping decorators, but not the async synchronicity interace."""
from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import

if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING:
return is_async_gen_function_follow_wrapped(func.__wrapped__)
return inspect.isasyncgenfunction(func)


YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
SEND_TYPE = typing.TypeVar("SEND_TYPE")

Expand Down
8 changes: 3 additions & 5 deletions src/synchronicity/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from synchronicity.annotations import evaluated_annotation
from synchronicity.combined_types import FunctionWithAio, MethodWithAio

from .async_wrap import wraps_by_interface
from .async_wrap import is_async_gen_function_follow_wrapped, is_coroutine_function_follow_wrapped, wraps_by_interface
from .callback import Callback
from .exceptions import UserCodeException, unwrap_coro_exception, wrap_coro_exception
from .interface import DEFAULT_CLASS_PREFIX, DEFAULT_FUNCTION_PREFIXES, Interface
Expand Down Expand Up @@ -79,7 +79,7 @@ def _type_requires_aio_usage(annotation, declaration_module):

def should_have_aio_interface(func):
# determines if a blocking function gets an .aio attribute with an async interface to the function or not
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
if is_coroutine_function_follow_wrapped(func) or is_async_gen_function_follow_wrapped(func):
return True
# check annotations if they contain any async entities that would need an event loop to be translated:
# This catches things like vanilla functions returning Coroutines
Expand Down Expand Up @@ -468,8 +468,6 @@ def _wrap_callable(
else:
_name = name

is_coroutinefunction = inspect.iscoroutinefunction(f)

@wraps_by_interface(interface, f)
def f_wrapped(*args, **kwargs):
return_future = kwargs.pop(_RETURN_FUTURE_KWARG, False)
Expand Down Expand Up @@ -499,7 +497,7 @@ def f_wrapped(*args, **kwargs):
elif is_coroutine:
if interface == Interface._ASYNC_WITH_BLOCKING_TYPES:
coro = self._run_function_async(res, f)
if not is_coroutinefunction:
if not is_coroutine_function_follow_wrapped(f):
# If this is a non-async function that returns a coroutine,
# then this is the exit point, and we need to unwrap any
# wrapped exception here. Otherwise, the exit point is
Expand Down
4 changes: 3 additions & 1 deletion src/synchronicity/type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import synchronicity
from synchronicity import combined_types, overload_tracking
from synchronicity.annotations import TYPE_CHECKING_OVERRIDES, evaluated_annotation
from synchronicity.async_wrap import is_coroutine_function_follow_wrapped
from synchronicity.interface import Interface
from synchronicity.synchronizer import (
SYNCHRONIZER_ATTR,
Expand Down Expand Up @@ -390,6 +391,7 @@ def final_transform_signature(sig):
{aio_func_source}
{body_indent}{entity_name}: __{entity_name}_spec{parent_type_var_names_spec}
"""

return protocol_attr

def _prepare_method_generic_type_vars(self, entity, parent_generic_type_vars):
Expand Down Expand Up @@ -856,7 +858,7 @@ def _get_function_source(
maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n"

async_prefix = ""
if inspect.iscoroutinefunction(func):
if is_coroutine_function_follow_wrapped(func):
# note: async prefix should not be used for annotated abstract/stub *async generators*,
# so we don't check for inspect.isasyncgenfunction since they contain no yield keyword,
# and would otherwise indicate an awaitable that returns an async generator to static type checkers
Expand Down
10 changes: 10 additions & 0 deletions test/type_stub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def wrapper(extra_arg: int, *args, **kwargs):
assert _function_source(wrapper) == "def orig(extra_arg: int, arg: float):\n ...\n"


def test_wrapped_async_func_remains_async():
async def orig(arg: str): ...

@functools.wraps(orig)
def wrapper(*args, **kwargs):
return orig(*args, **kwargs)

assert _function_source(wrapper) == "async def orig(arg: str):\n ...\n"


class Base:
def base_method(self) -> str:
return ""
Expand Down

0 comments on commit 42f549b

Please sign in to comment.