diff --git a/src/synchronicity/async_wrap.py b/src/synchronicity/async_wrap.py index 64ba0d8..27f92b0 100644 --- a/src/synchronicity/async_wrap.py +++ b/src/synchronicity/async_wrap.py @@ -39,7 +39,7 @@ async def wrapper(*args, **kwargs): def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: - """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicitiy interace.""" + """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: @@ -47,6 +47,16 @@ def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: return inspect.iscoroutinefunction(func) +def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: + """Determine if a function returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" + from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import + + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: + return is_async_gen_function_follow_wrapped(func.__wrapped__) + return inspect.isasyncgenfunction(func) + + + YIELD_TYPE = typing.TypeVar("YIELD_TYPE") SEND_TYPE = typing.TypeVar("SEND_TYPE") diff --git a/src/synchronicity/synchronizer.py b/src/synchronicity/synchronizer.py index 79080ea..d671f12 100644 --- a/src/synchronicity/synchronizer.py +++ b/src/synchronicity/synchronizer.py @@ -20,7 +20,7 @@ from synchronicity.annotations import evaluated_annotation from synchronicity.combined_types import FunctionWithAio, MethodWithAio -from .async_wrap import wraps_by_interface +from .async_wrap import is_async_gen_function_follow_wrapped, is_coroutine_function_follow_wrapped, wraps_by_interface from .callback import Callback from .exceptions import UserCodeException, unwrap_coro_exception, wrap_coro_exception from .interface import DEFAULT_CLASS_PREFIX, DEFAULT_FUNCTION_PREFIXES, Interface @@ -56,18 +56,6 @@ ) -def iscoroutinefunction_follow_wrapped(func): - if hasattr(func, "__wrapped__"): - return iscoroutinefunction_follow_wrapped(func.__wrapped__) - return inspect.iscoroutinefunction(func) - - -def isasyncgenfunction_follow_wrapped(func): - if hasattr(func, "__wrapped__"): - return isasyncgenfunction_follow_wrapped(func.__wrapped__) - return inspect.isasyncgenfunction(func) - - def _type_requires_aio_usage(annotation, declaration_module): if isinstance(annotation, ForwardRef): annotation = annotation.__forward_arg__ @@ -91,7 +79,7 @@ def _type_requires_aio_usage(annotation, declaration_module): def should_have_aio_interface(func): # determines if a blocking function gets an .aio attribute with an async interface to the function or not - if iscoroutinefunction_follow_wrapped(func) or isasyncgenfunction_follow_wrapped(func): + if is_coroutine_function_follow_wrapped(func) or is_async_gen_function_follow_wrapped(func): return True # check annotations if they contain any async entities that would need an event loop to be translated: # This catches things like vanilla functions returning Coroutines @@ -480,8 +468,6 @@ def _wrap_callable( else: _name = name - is_coroutinefunction = iscoroutinefunction_follow_wrapped(f) - @wraps_by_interface(interface, f) def f_wrapped(*args, **kwargs): return_future = kwargs.pop(_RETURN_FUTURE_KWARG, False) @@ -511,7 +497,7 @@ def f_wrapped(*args, **kwargs): elif is_coroutine: if interface == Interface._ASYNC_WITH_BLOCKING_TYPES: coro = self._run_function_async(res, f) - if not is_coroutinefunction: + if not is_coroutine_function_follow_wrapped(f): # If this is a non-async function that returns a coroutine, # then this is the exit point, and we need to unwrap any # wrapped exception here. Otherwise, the exit point is diff --git a/src/synchronicity/type_stubs.py b/src/synchronicity/type_stubs.py index d9e1e93..1d3c062 100644 --- a/src/synchronicity/type_stubs.py +++ b/src/synchronicity/type_stubs.py @@ -857,11 +857,6 @@ def _get_function_source( self.imports.add("typing_extensions") maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n" - def is_async(func): - if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: - return is_async(func.__wrapped__) - return inspect.iscoroutinefunction(func) - async_prefix = "" if is_coroutine_function_follow_wrapped(func): # note: async prefix should not be used for annotated abstract/stub *async generators*,