diff --git a/a_sync/decorator.py b/a_sync/decorator.py index 80b5f55f..6b505482 100644 --- a/a_sync/decorator.py +++ b/a_sync/decorator.py @@ -1,6 +1,6 @@ import functools -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Literal, Optional from typing_extensions import ParamSpec # type: ignore @@ -10,25 +10,36 @@ T = TypeVar("T") -def a_sync(coro_fn: Callable[P, T]) -> Callable[P, T]: # type: ignore +def a_sync(default: Optional[Literal['sync','async']] = None) -> Callable[[Callable[P, T]], Callable[P, T]]: f""" A coroutine function decorated with this decorator can be called as a sync function by passing a boolean value for any of these kwargs: {_helpers._flag_name_options} """ + if default not in ['async', 'sync', None]: + if callable(default): + return a_sync()(default) + raise ValueError(f"'default' must be either 'sync', 'async', or None. You passed {default}.") - _helpers._validate_wrapped_fn(coro_fn) + def a_sync_deco(coro_fn: Callable[P, T]) -> Callable[P, T]: # type: ignore + + _helpers._validate_wrapped_fn(coro_fn) - @functools.wraps(coro_fn) - def a_sync_wrap(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore - # If a flag was specified in the kwargs, we will defer to it. - for flag in _helpers._flag_name_options: - if flag in kwargs: - val = kwargs.pop(flag) - if not isinstance(val, bool): - raise TypeError(f"'{flag}' must be boolean. You passed {val}.") - return _helpers._await_if_sync( # type: ignore - coro_fn(*args, **kwargs), - val if flag == 'sync' else not val - ) - # No flag specified in the kwargs, we will just return the awaitable. - return coro_fn(*args, **kwargs) - return a_sync_wrap + @functools.wraps(coro_fn) + def a_sync_wrap(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore + # If a flag was specified in the kwargs, we will defer to it. + for flag in _helpers._flag_name_options: + if flag in kwargs: + val = kwargs.pop(flag) + if not isinstance(val, bool): + raise TypeError(f"'{flag}' must be boolean. You passed {val}.") + return _helpers._await_if_sync( + coro_fn(*args, **kwargs), + val if flag == 'sync' else not val + ) + + # No flag specified in the kwargs, we will defer to 'default'. + return _helpers._await_if_sync( + coro_fn(*args, **kwargs), + True if default == 'sync' else False + ) + return a_sync_wrap + return a_sync_deco diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 76984946..944ce5bb 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,13 +1,86 @@ import asyncio +import pytest +from typing import Literal import a_sync -@a_sync.a_sync -async def some_test_fn() -> int: - return 2 -def test_decorator(): +def _test_kwargs(fn, default: Literal['sync','async',None]): + # force async + assert asyncio.get_event_loop().run_until_complete(fn(sync=False)) == 2 + assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=True)) == 2 + # force sync + with pytest.raises(TypeError): + assert asyncio.get_event_loop().run_until_complete(fn(sync=True)) == 2 + with pytest.raises(TypeError): + assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=False)) == 2 + assert fn(sync=True) == 2 + assert fn(asynchronous=False) == 2 + if default == 'sync': + assert fn() == 2 + elif default == 'async': + assert asyncio.get_event_loop().run_until_complete(fn()) == 2 + elif default is None: + assert asyncio.get_event_loop().run_until_complete(fn()) == 2 + else: + raise NotImplementedError(default) + +def test_decorator_no_args(): + @a_sync.a_sync + async def some_test_fn() -> int: + return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 + _test_kwargs(some_test_fn, None) + + @a_sync.a_sync() + async def some_test_fn() -> int: + return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 + _test_kwargs(some_test_fn, None) + +def test_decorator_default_none_arg(): + @a_sync.a_sync(None) + async def some_test_fn() -> int: + return 2 + asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 + _test_kwargs(some_test_fn, None) + +def test_decorator_default_none_kwarg(): + @a_sync.a_sync(default=None) + async def some_test_fn() -> int: + return 2 + asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 + _test_kwargs(some_test_fn, None) + +def test_decorator_default_sync_arg(): + @a_sync.a_sync('sync') + async def some_test_fn() -> int: + return 2 + with pytest.raises(TypeError): + asyncio.get_event_loop().run_until_complete(some_test_fn()) + assert some_test_fn() == 2 + _test_kwargs(some_test_fn, 'sync') + +def test_decorator_default_sync_kwarg(): + @a_sync.a_sync(default='sync') + async def some_test_fn() -> int: + return 2 + with pytest.raises(TypeError): + asyncio.get_event_loop().run_until_complete(some_test_fn()) + assert some_test_fn() == 2 + _test_kwargs(some_test_fn, 'sync') + +def test_decorator_default_async_arg(): + @a_sync.a_sync('async') + async def some_test_fn() -> int: + return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 + _test_kwargs(some_test_fn, 'async') + +def test_decorator_default_async_kwarg(): + @a_sync.a_sync(default='async') + async def some_test_fn() -> int: + return 2 assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - assert some_test_fn(sync=True) == 2 - assert some_test_fn(asynchronous=False) == 2 + _test_kwargs(some_test_fn, 'async')