diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index 8e498031..ee7274ec 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -43,7 +43,7 @@ def __init__( self.lock = None self.exits: List[Exit] = [] - def _get_child( + def _create_child( self, context: Optional[dict], with_lock: bool, @@ -68,35 +68,29 @@ def __call__( """ if not self.child_registries: raise ValueError("No child scopes found") - return AsyncContextWrapper(self._get_child(context, with_lock)) + return AsyncContextWrapper(self._create_child(context, with_lock)) - async def _get_parent(self, dependency_type: Type[T]) -> T: - return await self.parent_container.get(dependency_type) - - async def _get_self( - self, - dep_provider: Factory, - ) -> T: + async def _get_from_self(self, factory: Factory) -> T: sub_dependencies = [ await self._get_unlocked(dependency) - for dependency in dep_provider.dependencies + for dependency in factory.dependencies ] - if dep_provider.type is FactoryType.GENERATOR: - generator = dep_provider.source(*sub_dependencies) - self.exits.append(Exit(dep_provider.type, generator)) + if factory.type is FactoryType.GENERATOR: + generator = factory.source(*sub_dependencies) + self.exits.append(Exit(factory.type, generator)) return next(generator) - elif dep_provider.type is FactoryType.ASYNC_GENERATOR: - generator = dep_provider.source(*sub_dependencies) - self.exits.append(Exit(dep_provider.type, generator)) + elif factory.type is FactoryType.ASYNC_GENERATOR: + generator = factory.source(*sub_dependencies) + self.exits.append(Exit(factory.type, generator)) return await anext(generator) - elif dep_provider.type is FactoryType.ASYNC_FACTORY: - return await dep_provider.source(*sub_dependencies) - elif dep_provider.type is FactoryType.FACTORY: - return dep_provider.source(*sub_dependencies) - elif dep_provider.type is FactoryType.VALUE: - return dep_provider.source + elif factory.type is FactoryType.ASYNC_FACTORY: + return await factory.source(*sub_dependencies) + elif factory.type is FactoryType.FACTORY: + return factory.source(*sub_dependencies) + elif factory.type is FactoryType.VALUE: + return factory.source else: - raise ValueError(f"Unsupported type {dep_provider.type}") + raise ValueError(f"Unsupported type {factory.type}") async def get(self, dependency_type: Type[T]) -> T: lock = self.lock @@ -113,7 +107,7 @@ async def _get_unlocked(self, dependency_type: Type[T]) -> T: if not self.parent_container: raise ValueError(f"No provider found for {dependency_type!r}") return await self.parent_container.get(dependency_type) - solved = await self._get_self(provider) + solved = await self._get_from_self(provider) self.context[dependency_type] = solved return solved diff --git a/src/dishka/container.py b/src/dishka/container.py index 75869acc..65adbc8e 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -43,7 +43,7 @@ def __init__( self.lock = None self.exits: List[Exit] = [] - def _get_child( + def _create_child( self, context: Optional[dict], with_lock: bool, @@ -68,29 +68,23 @@ def __call__( """ if not self.child_registries: raise ValueError("No child scopes found") - return ContextWrapper(self._get_child(context, with_lock)) + return ContextWrapper(self._create_child(context, with_lock)) - def _get_parent(self, dependency_type: Type[T]) -> T: - return self.parent_container.get(dependency_type) - - def _get_self( - self, - dep_provider: Factory, - ) -> T: + def _get_from_self(self, factory: Factory) -> T: sub_dependencies = [ self._get_unlocked(dependency) - for dependency in dep_provider.dependencies + for dependency in factory.dependencies ] - if dep_provider.type is FactoryType.GENERATOR: - generator = dep_provider.source(*sub_dependencies) - self.exits.append(Exit(dep_provider.type, generator)) + if factory.type is FactoryType.GENERATOR: + generator = factory.source(*sub_dependencies) + self.exits.append(Exit(factory.type, generator)) return next(generator) - elif dep_provider.type is FactoryType.FACTORY: - return dep_provider.source(*sub_dependencies) - elif dep_provider.type is FactoryType.VALUE: - return dep_provider.source + elif factory.type is FactoryType.FACTORY: + return factory.source(*sub_dependencies) + elif factory.type is FactoryType.VALUE: + return factory.source else: - raise ValueError(f"Unsupported type {dep_provider.type}") + raise ValueError(f"Unsupported type {factory.type}") def get(self, dependency_type: Type[T]) -> T: lock = self.lock @@ -107,7 +101,7 @@ def _get_unlocked(self, dependency_type: Type[T]) -> T: if not self.parent_container: raise ValueError(f"No provider found for {dependency_type!r}") return self.parent_container.get(dependency_type) - solved = self._get_self(provider) + solved = self._get_from_self(provider) self.context[dependency_type] = solved return solved diff --git a/src/dishka/provider.py b/src/dishka/provider.py index 372a380a..685072c4 100644 --- a/src/dishka/provider.py +++ b/src/dishka/provider.py @@ -1,6 +1,11 @@ -from typing import List +import inspect +from typing import Any, List -from .dependency_source import DependencySource +from .dependency_source import Alias, Decorator, DependencySource, Factory + + +def is_dependency_source(attribute: Any) -> bool: + return isinstance(attribute, DependencySource) class Provider: @@ -18,8 +23,26 @@ class Provider: """ def __init__(self): - self.dependency_sources: List[DependencySource] = [ - getattr(self, name) - for name, attr in vars(type(self)).items() - if isinstance(attr, DependencySource) - ] + self.factories: List[Factory] = [] + self.aliases: List[Alias] = [] + self.decorators: List[Decorator] = [] + self._init_dependency_sources() + + def _init_dependency_sources(self) -> None: + processed_types = {} + + source: DependencySource + for name, source in inspect.getmembers(self, is_dependency_source): + if source.provides in processed_types: + raise ValueError( + f"Type {source.provides} is registered multiple times " + f"in the same {Provider} by attributes " + f"{processed_types[source.provides]!r} and {name!r}", + ) + if isinstance(source, Alias): + self.aliases.append(source) + if isinstance(source, Factory): + self.factories.append(source) + if isinstance(source, Decorator): + self.decorators.append(source) + processed_types[source.provides] = name diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 40c77932..f92d6168 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import Any, List, NewType, Type -from .dependency_source import Alias, Decorator, Factory +from .dependency_source import Factory from .provider import Provider from .scope import BaseScope @@ -23,40 +23,48 @@ def get_provider(self, dependency: Any) -> Factory: def make_registries( *providers: Provider, scopes: Type[BaseScope], ) -> List[Registry]: - dep_scopes = {} + dep_scopes: dict[Type, BaseScope] = {} + alias_sources = {} for provider in providers: - for source in provider.dependency_sources: - if hasattr(source, "scope"): - dep_scopes[source.provides] = source.scope + for source in provider.factories: + dep_scopes[source.provides] = source.scope + for source in provider.aliases: + alias_sources[source.provides] = source.source registries = {scope: Registry(scope) for scope in scopes} decorator_depth: dict[Type, int] = defaultdict(int) for provider in providers: - for source in provider.dependency_sources: + for source in provider.factories: + scope = source.scope + registries[scope].add_provider(source) + for source in provider.aliases: + alias_source = source.source + visited_types = [alias_source] + while alias_source not in dep_scopes: + alias_source = alias_sources[alias_source] + if alias_source in visited_types: + raise ValueError(f"Cycle aliases detected {visited_types}") + visited_types.append(alias_source) + scope = dep_scopes[alias_source] + dep_scopes[source.provides] = scope + source = source.as_factory(scope) + registries[scope].add_provider(source) + for source in provider.decorators: provides = source.provides - if isinstance(source, Factory): - scope = source.scope - elif isinstance(source, Alias): - scope = dep_scopes[source.source] - dep_scopes[provides] = scope - source = source.as_factory(scope) - elif isinstance(source, Decorator): - scope = dep_scopes[provides] - registry = registries[scope] - undecorated_type = NewType( - f"{provides.__name__}@{decorator_depth[provides]}", - source.provides, - ) - decorator_depth[provides] += 1 - old_provider = registry.get_provider(provides) - old_provider.provides = undecorated_type - registry.add_provider(old_provider) - source = source.as_factory( - scope, undecorated_type, - ) - else: - raise ValueError("Unknown dependency source type") + scope = dep_scopes[provides] + registry = registries[scope] + undecorated_type = NewType( + f"{provides.__name__}@{decorator_depth[provides]}", + source.provides, + ) + decorator_depth[provides] += 1 + old_provider = registry.get_provider(provides) + old_provider.provides = undecorated_type + registry.add_provider(old_provider) + source = source.as_factory( + scope, undecorated_type, + ) registries[scope].add_provider(source) return list(registries.values()) diff --git a/tests/container/test_alias.py b/tests/container/test_alias.py new file mode 100644 index 00000000..b91315ad --- /dev/null +++ b/tests/container/test_alias.py @@ -0,0 +1,33 @@ +import pytest + +from dishka import Provider, Scope, alias, make_container, provide + + +class AliasProvider(Provider): + @provide(scope=Scope.APP) + def provide_int(self) -> int: + return 42 + + aliased_complex = alias(source=float, provides=complex) + aliased_float = alias(source=int, provides=float) + + +def test_alias(): + with make_container(AliasProvider()) as container: + assert container.get(float) == container.get(int) + + +def test_alias_to_alias(): + with make_container(AliasProvider()) as container: + assert container.get(complex) == container.get(int) + + +class CycleProvider(Provider): + a = alias(source=int, provides=bool) + b = alias(source=bool, provides=float) + c = alias(source=float, provides=int) + + +def test_cycle(): + with pytest.raises(ValueError): + make_container(CycleProvider()) diff --git a/tests/container/test_decorator.py b/tests/container/test_decorator.py index 0598ba03..4ed84de1 100644 --- a/tests/container/test_decorator.py +++ b/tests/container/test_decorator.py @@ -1,3 +1,5 @@ +import pytest + from dishka import Provider, Scope, alias, decorate, make_container, provide @@ -21,9 +23,11 @@ def __init__(self, a: A): def test_simple(): class MyProvider(Provider): a = provide(A, scope=Scope.APP) + + class DProvider(Provider): ad = decorate(ADecorator, provides=A) - with make_container(MyProvider()) as container: + with make_container(MyProvider(), DProvider()) as container: a = container.get(A) assert isinstance(a, ADecorator) assert isinstance(a.a, A) @@ -35,11 +39,12 @@ class MyProvider(Provider): a1 = alias(source=A2, provides=A1) a = alias(source=A1, provides=A) + class DProvider(Provider): @decorate def decorated(self, a: A1) -> A1: return ADecorator(a) - with make_container(MyProvider()) as container: + with make_container(MyProvider(), DProvider()) as container: a1 = container.get(A1) assert isinstance(a1, ADecorator) assert isinstance(a1.a, A2) @@ -52,13 +57,27 @@ def decorated(self, a: A1) -> A1: assert a is a1 -def test_double(): +def test_double_error(): class MyProvider(Provider): a = provide(A, scope=Scope.APP) ad = decorate(ADecorator, provides=A) ad2 = decorate(ADecorator, provides=A) - with make_container(MyProvider()) as container: + with pytest.raises(ValueError): + MyProvider() + + +def test_double_ok(): + class MyProvider(Provider): + a = provide(A, scope=Scope.APP) + + class DProvider(Provider): + ad = decorate(ADecorator, provides=A) + + class D2Provider(Provider): + ad2 = decorate(ADecorator, provides=A) + + with make_container(MyProvider(), DProvider(), D2Provider()) as container: a = container.get(A) assert isinstance(a, ADecorator) assert isinstance(a.a, ADecorator) diff --git a/tests/test_provider.py b/tests/test_provider.py index 098bb440..578fe024 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -16,14 +16,14 @@ def test_provider_init(): class MyProvider(Provider): a = alias(source=int, provides=bool) - b = provide(lambda: False, scope=Scope.APP, provides=bool) @provide(scope=Scope.REQUEST) def foo(self, x: bool) -> str: return f"{x}" provider = MyProvider() - assert len(provider.dependency_sources) == 3 + assert len(provider.factories) == 1 + assert len(provider.aliases) == 1 @pytest.mark.parametrize(