Skip to content

Commit

Permalink
support staticmethod/classmethod/other callables
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Feb 17, 2024
1 parent 3736a39 commit f361721
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 27 deletions.
115 changes: 88 additions & 27 deletions src/dishka/dependency_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from collections.abc import AsyncIterable, Iterable
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Generator,
Iterable,
Iterator,
)
from enum import Enum
from inspect import (
isasyncgenfunction,
Expand Down Expand Up @@ -101,6 +108,45 @@ def _get_init_members(tp) -> MembersStorage[str, None]:
)


def _guess_factory_type(source):
if isasyncgenfunction(source):
return FactoryType.ASYNC_GENERATOR
elif isgeneratorfunction(source):
return FactoryType.GENERATOR
elif iscoroutinefunction(source):
return FactoryType.ASYNC_FACTORY
else:
return FactoryType.FACTORY


def _clean_result_hint(factory_type: FactoryType, possible_dependency: Any):
if factory_type == FactoryType.ASYNC_GENERATOR:
origin = get_origin(possible_dependency)
if origin is AsyncIterable:
return get_args(possible_dependency)[0]
elif origin is AsyncIterator:
return get_args(possible_dependency)[0]
elif origin is AsyncGenerator:
return get_args(possible_dependency)[0]
else:
raise TypeError(
f"Unsupported return type {possible_dependency} {origin} "
f"for async generator")
elif factory_type == FactoryType.GENERATOR:
origin = get_origin(possible_dependency)
if origin is Iterable:
return get_args(possible_dependency)[0]
elif origin is Iterator:
return get_args(possible_dependency)[0]
elif origin is Generator:
return get_args(possible_dependency)[1]
else:
raise TypeError(
f"Unsupported return type {possible_dependency} {origin}"
f" for generator")
return possible_dependency


def make_factory(
provides: Any,
scope: Optional[BaseScope],
Expand All @@ -109,47 +155,62 @@ def make_factory(
) -> Factory:
if is_bare_generic(source):
source = source[get_type_vars(source)]

if isclass(source) or get_origin(source):
# we need to fix concrete generics and normal classes as well
# as classes can be children of concrete generics
res = GenericResolver(_get_init_members)
hints = dict(res.get_resolved_members(source).members)
hints.pop("return", None)
possible_dependency = source
dependencies = list(hints.values())
if not provides:
provides = source
is_to_bind = False
elif isfunction(source):
params = signature(source).parameters
factory_type = FactoryType.FACTORY
elif isfunction(source) or isinstance(source, classmethod):
if isinstance(source, classmethod):
params = signature(source.__wrapped__).parameters
factory_type = _guess_factory_type(source.__wrapped__)
else:
params = signature(source).parameters
factory_type = _guess_factory_type(source)

self = next(iter(params.values()))
hints = get_type_hints(source, include_extras=True)
hints.pop(self.name, None)
possible_dependency = hints.pop("return", None)
dependencies = list(hints.values())
if not provides:
provides = _clean_result_hint(factory_type, possible_dependency)
is_to_bind = True
elif isinstance(source, staticmethod):
factory_type = _guess_factory_type(source.__wrapped__)
hints = get_type_hints(source, include_extras=True)
possible_dependency = hints.pop("return", None)
dependencies = list(hints.values())
if not provides:
provides = _clean_result_hint(factory_type, possible_dependency)
is_to_bind = False
elif callable(source):
factory = make_factory(
provides=provides,
source=type(source).__call__,
cache=cache,
scope=scope,
)
factory_type = factory.type
dependencies = factory.dependencies
provides = factory.provides
is_to_bind = False
else:
raise TypeError(f"Cannot use {type(source)} as a factory")

if isasyncgenfunction(source):
provider_type = FactoryType.ASYNC_GENERATOR
if get_origin(possible_dependency) is AsyncIterable:
possible_dependency = get_args(possible_dependency)[0]
else: # async generator
possible_dependency = get_args(possible_dependency)[0]
elif isgeneratorfunction(source):
provider_type = FactoryType.GENERATOR
if get_origin(possible_dependency) is Iterable:
possible_dependency = get_args(possible_dependency)[0]
else: # generator
possible_dependency = get_args(possible_dependency)[1]
elif iscoroutinefunction(source):
provider_type = FactoryType.ASYNC_FACTORY
else:
provider_type = FactoryType.FACTORY

return Factory(
dependencies=list(hints.values()),
type=provider_type,
dependencies=dependencies,
type=factory_type,
source=source,
scope=scope,
provides=provides or possible_dependency,
provides=provides,
is_to_bound=is_to_bind,
cache=cache,
)
Expand All @@ -167,7 +228,7 @@ def provide(

@overload
def provide(
source: Callable | Type,
source: Callable | classmethod | staticmethod | Type | None,
*,
scope: BaseScope,
provides: Any = None,
Expand All @@ -177,7 +238,7 @@ def provide(


def provide(
source: Callable | Type | None = None,
source: Callable | classmethod | staticmethod | Type | None = None,
*,
scope: BaseScope | None = None,
provides: Any = None,
Expand Down Expand Up @@ -239,7 +300,7 @@ def alias(
*,
source: Type,
provides: Type,
cache: bool=True,
cache: bool = True,
) -> Alias:
return Alias(
source=source,
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,39 @@ def foo(self: Provider) -> str:

provider = MyProvider(scope=Scope.REQUEST)
assert not provider.foo.dependencies


def test_staticmethod():
class MyProvider(Provider):
@provide
@staticmethod
def foo() -> str:
return "hello"

provider = MyProvider(scope=Scope.REQUEST)
assert not provider.foo.dependencies


def test_classmethod():
class MyProvider(Provider):
@provide
@classmethod
def foo(cls: type) -> str:
return "hello"

provider = MyProvider(scope=Scope.REQUEST)
assert not provider.foo.dependencies


class MyCallable:
def __call__(self: object, param: int) -> str:
return "hello"


def test_callable():
class MyProvider(Provider):
foo = provide(MyCallable())

provider = MyProvider(scope=Scope.REQUEST)
assert provider.foo.provides == str
assert provider.foo.dependencies == [int]

0 comments on commit f361721

Please sign in to comment.