Skip to content

Commit

Permalink
feat: default kwarg (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Feb 26, 2023
1 parent eaadbbc commit da1c378
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 24 deletions.
47 changes: 29 additions & 18 deletions a_sync/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import functools
from typing import Callable, TypeVar
from typing import Callable, TypeVar, Literal, Optional

from typing_extensions import ParamSpec # type: ignore

Expand All @@ -10,25 +10,36 @@
T = TypeVar("T")


def a_sync(coro_fn: Callable[P, T]) -> Callable[P, T]: # type: ignore
def a_sync(default: Optional[Literal['sync','async']] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
f"""
A coroutine function decorated with this decorator can be called as a sync function by passing a boolean value for any of these kwargs: {_helpers._flag_name_options}
"""
if default not in ['async', 'sync', None]:
if callable(default):
return a_sync()(default)
raise ValueError(f"'default' must be either 'sync', 'async', or None. You passed {default}.")

_helpers._validate_wrapped_fn(coro_fn)
def a_sync_deco(coro_fn: Callable[P, T]) -> Callable[P, T]: # type: ignore

_helpers._validate_wrapped_fn(coro_fn)

@functools.wraps(coro_fn)
def a_sync_wrap(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore
# If a flag was specified in the kwargs, we will defer to it.
for flag in _helpers._flag_name_options:
if flag in kwargs:
val = kwargs.pop(flag)
if not isinstance(val, bool):
raise TypeError(f"'{flag}' must be boolean. You passed {val}.")
return _helpers._await_if_sync( # type: ignore
coro_fn(*args, **kwargs),
val if flag == 'sync' else not val
)
# No flag specified in the kwargs, we will just return the awaitable.
return coro_fn(*args, **kwargs)
return a_sync_wrap
@functools.wraps(coro_fn)
def a_sync_wrap(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore
# If a flag was specified in the kwargs, we will defer to it.
for flag in _helpers._flag_name_options:
if flag in kwargs:
val = kwargs.pop(flag)
if not isinstance(val, bool):
raise TypeError(f"'{flag}' must be boolean. You passed {val}.")
return _helpers._await_if_sync(
coro_fn(*args, **kwargs),
val if flag == 'sync' else not val
)

# No flag specified in the kwargs, we will defer to 'default'.
return _helpers._await_if_sync(
coro_fn(*args, **kwargs),
True if default == 'sync' else False
)
return a_sync_wrap
return a_sync_deco
85 changes: 79 additions & 6 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,86 @@

import asyncio
import pytest
from typing import Literal

import a_sync

@a_sync.a_sync
async def some_test_fn() -> int:
return 2

def test_decorator():
def _test_kwargs(fn, default: Literal['sync','async',None]):
# force async
assert asyncio.get_event_loop().run_until_complete(fn(sync=False)) == 2
assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=True)) == 2
# force sync
with pytest.raises(TypeError):
assert asyncio.get_event_loop().run_until_complete(fn(sync=True)) == 2
with pytest.raises(TypeError):
assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=False)) == 2
assert fn(sync=True) == 2
assert fn(asynchronous=False) == 2
if default == 'sync':
assert fn() == 2
elif default == 'async':
assert asyncio.get_event_loop().run_until_complete(fn()) == 2
elif default is None:
assert asyncio.get_event_loop().run_until_complete(fn()) == 2
else:
raise NotImplementedError(default)

def test_decorator_no_args():
@a_sync.a_sync
async def some_test_fn() -> int:
return 2
assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
_test_kwargs(some_test_fn, None)

@a_sync.a_sync()
async def some_test_fn() -> int:
return 2
assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
_test_kwargs(some_test_fn, None)

def test_decorator_default_none_arg():
@a_sync.a_sync(None)
async def some_test_fn() -> int:
return 2
asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
_test_kwargs(some_test_fn, None)

def test_decorator_default_none_kwarg():
@a_sync.a_sync(default=None)
async def some_test_fn() -> int:
return 2
asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
_test_kwargs(some_test_fn, None)

def test_decorator_default_sync_arg():
@a_sync.a_sync('sync')
async def some_test_fn() -> int:
return 2
with pytest.raises(TypeError):
asyncio.get_event_loop().run_until_complete(some_test_fn())
assert some_test_fn() == 2
_test_kwargs(some_test_fn, 'sync')

def test_decorator_default_sync_kwarg():
@a_sync.a_sync(default='sync')
async def some_test_fn() -> int:
return 2
with pytest.raises(TypeError):
asyncio.get_event_loop().run_until_complete(some_test_fn())
assert some_test_fn() == 2
_test_kwargs(some_test_fn, 'sync')

def test_decorator_default_async_arg():
@a_sync.a_sync('async')
async def some_test_fn() -> int:
return 2
assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
_test_kwargs(some_test_fn, 'async')

def test_decorator_default_async_kwarg():
@a_sync.a_sync(default='async')
async def some_test_fn() -> int:
return 2
assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2
assert some_test_fn(sync=True) == 2
assert some_test_fn(asynchronous=False) == 2
_test_kwargs(some_test_fn, 'async')

0 comments on commit da1c378

Please sign in to comment.