From f361721a7af42f1841b7ead0eba707de803bd4e4 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Sat, 17 Feb 2024 23:40:13 +0100 Subject: [PATCH] support staticmethod/classmethod/other callables --- src/dishka/dependency_source.py | 115 ++++++++++++++++++++++++-------- tests/unit/test_provider.py | 36 ++++++++++ 2 files changed, 124 insertions(+), 27 deletions(-) diff --git a/src/dishka/dependency_source.py b/src/dishka/dependency_source.py index b03800ea..77e8bca3 100644 --- a/src/dishka/dependency_source.py +++ b/src/dishka/dependency_source.py @@ -1,4 +1,11 @@ -from collections.abc import AsyncIterable, Iterable +from collections.abc import ( + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Generator, + Iterable, + Iterator, +) from enum import Enum from inspect import ( isasyncgenfunction, @@ -101,6 +108,45 @@ def _get_init_members(tp) -> MembersStorage[str, None]: ) +def _guess_factory_type(source): + if isasyncgenfunction(source): + return FactoryType.ASYNC_GENERATOR + elif isgeneratorfunction(source): + return FactoryType.GENERATOR + elif iscoroutinefunction(source): + return FactoryType.ASYNC_FACTORY + else: + return FactoryType.FACTORY + + +def _clean_result_hint(factory_type: FactoryType, possible_dependency: Any): + if factory_type == FactoryType.ASYNC_GENERATOR: + origin = get_origin(possible_dependency) + if origin is AsyncIterable: + return get_args(possible_dependency)[0] + elif origin is AsyncIterator: + return get_args(possible_dependency)[0] + elif origin is AsyncGenerator: + return get_args(possible_dependency)[0] + else: + raise TypeError( + f"Unsupported return type {possible_dependency} {origin} " + f"for async generator") + elif factory_type == FactoryType.GENERATOR: + origin = get_origin(possible_dependency) + if origin is Iterable: + return get_args(possible_dependency)[0] + elif origin is Iterator: + return get_args(possible_dependency)[0] + elif origin is Generator: + return get_args(possible_dependency)[1] + else: + raise TypeError( + f"Unsupported return type {possible_dependency} {origin}" + f" for generator") + return possible_dependency + + def make_factory( provides: Any, scope: Optional[BaseScope], @@ -109,47 +155,62 @@ def make_factory( ) -> Factory: if is_bare_generic(source): source = source[get_type_vars(source)] + if isclass(source) or get_origin(source): # we need to fix concrete generics and normal classes as well # as classes can be children of concrete generics res = GenericResolver(_get_init_members) hints = dict(res.get_resolved_members(source).members) hints.pop("return", None) - possible_dependency = source + dependencies = list(hints.values()) + if not provides: + provides = source is_to_bind = False - elif isfunction(source): - params = signature(source).parameters + factory_type = FactoryType.FACTORY + elif isfunction(source) or isinstance(source, classmethod): + if isinstance(source, classmethod): + params = signature(source.__wrapped__).parameters + factory_type = _guess_factory_type(source.__wrapped__) + else: + params = signature(source).parameters + factory_type = _guess_factory_type(source) + self = next(iter(params.values())) hints = get_type_hints(source, include_extras=True) hints.pop(self.name, None) possible_dependency = hints.pop("return", None) + dependencies = list(hints.values()) + if not provides: + provides = _clean_result_hint(factory_type, possible_dependency) is_to_bind = True + elif isinstance(source, staticmethod): + factory_type = _guess_factory_type(source.__wrapped__) + hints = get_type_hints(source, include_extras=True) + possible_dependency = hints.pop("return", None) + dependencies = list(hints.values()) + if not provides: + provides = _clean_result_hint(factory_type, possible_dependency) + is_to_bind = False + elif callable(source): + factory = make_factory( + provides=provides, + source=type(source).__call__, + cache=cache, + scope=scope, + ) + factory_type = factory.type + dependencies = factory.dependencies + provides = factory.provides + is_to_bind = False else: raise TypeError(f"Cannot use {type(source)} as a factory") - if isasyncgenfunction(source): - provider_type = FactoryType.ASYNC_GENERATOR - if get_origin(possible_dependency) is AsyncIterable: - possible_dependency = get_args(possible_dependency)[0] - else: # async generator - possible_dependency = get_args(possible_dependency)[0] - elif isgeneratorfunction(source): - provider_type = FactoryType.GENERATOR - if get_origin(possible_dependency) is Iterable: - possible_dependency = get_args(possible_dependency)[0] - else: # generator - possible_dependency = get_args(possible_dependency)[1] - elif iscoroutinefunction(source): - provider_type = FactoryType.ASYNC_FACTORY - else: - provider_type = FactoryType.FACTORY - return Factory( - dependencies=list(hints.values()), - type=provider_type, + dependencies=dependencies, + type=factory_type, source=source, scope=scope, - provides=provides or possible_dependency, + provides=provides, is_to_bound=is_to_bind, cache=cache, ) @@ -167,7 +228,7 @@ def provide( @overload def provide( - source: Callable | Type, + source: Callable | classmethod | staticmethod | Type | None, *, scope: BaseScope, provides: Any = None, @@ -177,7 +238,7 @@ def provide( def provide( - source: Callable | Type | None = None, + source: Callable | classmethod | staticmethod | Type | None = None, *, scope: BaseScope | None = None, provides: Any = None, @@ -239,7 +300,7 @@ def alias( *, source: Type, provides: Type, - cache: bool=True, + cache: bool = True, ) -> Alias: return Alias( source=source, diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index c02647e5..1803e059 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -87,3 +87,39 @@ def foo(self: Provider) -> str: provider = MyProvider(scope=Scope.REQUEST) assert not provider.foo.dependencies + + +def test_staticmethod(): + class MyProvider(Provider): + @provide + @staticmethod + def foo() -> str: + return "hello" + + provider = MyProvider(scope=Scope.REQUEST) + assert not provider.foo.dependencies + + +def test_classmethod(): + class MyProvider(Provider): + @provide + @classmethod + def foo(cls: type) -> str: + return "hello" + + provider = MyProvider(scope=Scope.REQUEST) + assert not provider.foo.dependencies + + +class MyCallable: + def __call__(self: object, param: int) -> str: + return "hello" + + +def test_callable(): + class MyProvider(Provider): + foo = provide(MyCallable()) + + provider = MyProvider(scope=Scope.REQUEST) + assert provider.foo.provides == str + assert provider.foo.dependencies == [int]