diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 5d062004..32e41895 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -2,13 +2,18 @@ from typing import ( Annotated, Any, + Awaitable, Callable, + Literal, Sequence, + Type, get_args, get_origin, get_type_hints, + overload, ) +from dishka.async_container import AsyncContainer from dishka.container import Container @@ -20,7 +25,7 @@ def __init__(self, param: Any = None): def default_parse_dependency( parameter: Parameter, hint: Any, - depends_class: Any = Depends, + depends_class: Type[Any] = Depends, ) -> Any: """ Resolve dependency type or return None if it is not a dependency """ if get_origin(hint) is not Annotated: @@ -40,14 +45,41 @@ def default_parse_dependency( DependencyParser = Callable[[Parameter, Any], Any] +@overload def wrap_injection( + *, func: Callable, container_getter: Callable[[tuple, dict], Container], + is_async: Literal[False] = False, remove_depends: bool = True, additional_params: Sequence[Parameter] = (), - is_async: bool = False, parse_dependency: DependencyParser = default_parse_dependency, ) -> Callable: + ... + + +@overload +def wrap_injection( + *, + func: Callable, + container_getter: Callable[[tuple, dict], AsyncContainer], + is_async: Literal[True], + remove_depends: bool = True, + additional_params: Sequence[Parameter] = (), + parse_dependency: DependencyParser = default_parse_dependency, +) -> Awaitable: + ... + + +def wrap_injection( + *, + func: Callable, + container_getter, + is_async: bool = False, + remove_depends: bool = True, + additional_params: Sequence[Parameter] = (), + parse_dependency: DependencyParser = default_parse_dependency, +): hints = get_type_hints(func, include_extras=True) func_signature = signature(func)