From 42f549b01d229c8b8588ebd73f136812adf6bc3b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 17 Dec 2024 09:24:53 -0500 Subject: [PATCH] Emit correct type stubs for async functions wrapped with additional decorators (#194) * Emit correct type stubs for async functions wrapped with additional decorators * Address comments * Simplify * Lint --- src/synchronicity/async_wrap.py | 18 ++++++++++++++++++ src/synchronicity/synchronizer.py | 8 +++----- src/synchronicity/type_stubs.py | 4 +++- test/type_stub_test.py | 10 ++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/synchronicity/async_wrap.py b/src/synchronicity/async_wrap.py index e81cec3..f60a9c1 100644 --- a/src/synchronicity/async_wrap.py +++ b/src/synchronicity/async_wrap.py @@ -38,6 +38,24 @@ async def wrapper(*args, **kwargs): return functools.wraps(func) +def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: + """Determine if func returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" + from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import + + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: + return is_coroutine_function_follow_wrapped(func.__wrapped__) + return inspect.iscoroutinefunction(func) + + +def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: + """Determine if func returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" + from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import + + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: + return is_async_gen_function_follow_wrapped(func.__wrapped__) + return inspect.isasyncgenfunction(func) + + YIELD_TYPE = typing.TypeVar("YIELD_TYPE") SEND_TYPE = typing.TypeVar("SEND_TYPE") diff --git a/src/synchronicity/synchronizer.py b/src/synchronicity/synchronizer.py index e911e72..d671f12 100644 --- a/src/synchronicity/synchronizer.py +++ b/src/synchronicity/synchronizer.py @@ -20,7 +20,7 @@ from synchronicity.annotations import evaluated_annotation from synchronicity.combined_types import FunctionWithAio, MethodWithAio -from .async_wrap import wraps_by_interface +from .async_wrap import is_async_gen_function_follow_wrapped, is_coroutine_function_follow_wrapped, wraps_by_interface from .callback import Callback from .exceptions import UserCodeException, unwrap_coro_exception, wrap_coro_exception from .interface import DEFAULT_CLASS_PREFIX, DEFAULT_FUNCTION_PREFIXES, Interface @@ -79,7 +79,7 @@ def _type_requires_aio_usage(annotation, declaration_module): def should_have_aio_interface(func): # determines if a blocking function gets an .aio attribute with an async interface to the function or not - if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + if is_coroutine_function_follow_wrapped(func) or is_async_gen_function_follow_wrapped(func): return True # check annotations if they contain any async entities that would need an event loop to be translated: # This catches things like vanilla functions returning Coroutines @@ -468,8 +468,6 @@ def _wrap_callable( else: _name = name - is_coroutinefunction = inspect.iscoroutinefunction(f) - @wraps_by_interface(interface, f) def f_wrapped(*args, **kwargs): return_future = kwargs.pop(_RETURN_FUTURE_KWARG, False) @@ -499,7 +497,7 @@ def f_wrapped(*args, **kwargs): elif is_coroutine: if interface == Interface._ASYNC_WITH_BLOCKING_TYPES: coro = self._run_function_async(res, f) - if not is_coroutinefunction: + if not is_coroutine_function_follow_wrapped(f): # If this is a non-async function that returns a coroutine, # then this is the exit point, and we need to unwrap any # wrapped exception here. Otherwise, the exit point is diff --git a/src/synchronicity/type_stubs.py b/src/synchronicity/type_stubs.py index caa3e28..1d3c062 100644 --- a/src/synchronicity/type_stubs.py +++ b/src/synchronicity/type_stubs.py @@ -27,6 +27,7 @@ import synchronicity from synchronicity import combined_types, overload_tracking from synchronicity.annotations import TYPE_CHECKING_OVERRIDES, evaluated_annotation +from synchronicity.async_wrap import is_coroutine_function_follow_wrapped from synchronicity.interface import Interface from synchronicity.synchronizer import ( SYNCHRONIZER_ATTR, @@ -390,6 +391,7 @@ def final_transform_signature(sig): {aio_func_source} {body_indent}{entity_name}: __{entity_name}_spec{parent_type_var_names_spec} """ + return protocol_attr def _prepare_method_generic_type_vars(self, entity, parent_generic_type_vars): @@ -856,7 +858,7 @@ def _get_function_source( maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n" async_prefix = "" - if inspect.iscoroutinefunction(func): + if is_coroutine_function_follow_wrapped(func): # note: async prefix should not be used for annotated abstract/stub *async generators*, # so we don't check for inspect.isasyncgenfunction since they contain no yield keyword, # and would otherwise indicate an awaitable that returns an async generator to static type checkers diff --git a/test/type_stub_test.py b/test/type_stub_test.py index 973cc59..02a5663 100644 --- a/test/type_stub_test.py +++ b/test/type_stub_test.py @@ -178,6 +178,16 @@ def wrapper(extra_arg: int, *args, **kwargs): assert _function_source(wrapper) == "def orig(extra_arg: int, arg: float):\n ...\n" +def test_wrapped_async_func_remains_async(): + async def orig(arg: str): ... + + @functools.wraps(orig) + def wrapper(*args, **kwargs): + return orig(*args, **kwargs) + + assert _function_source(wrapper) == "async def orig(arg: str):\n ...\n" + + class Base: def base_method(self) -> str: return ""