Skip to content

Commit

Permalink
Merge pull request #22 from reagento/refactor/dependency_sources_list
Browse files Browse the repository at this point in the history
use inspect to collect dependency sources
  • Loading branch information
Tishka17 authored Jan 29, 2024
2 parents e1799aa + 8e571da commit 3a89466
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 84 deletions.
42 changes: 18 additions & 24 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
self.lock = None
self.exits: List[Exit] = []

def _get_child(
def _create_child(
self,
context: Optional[dict],
with_lock: bool,
Expand All @@ -68,35 +68,29 @@ def __call__(
"""
if not self.child_registries:
raise ValueError("No child scopes found")
return AsyncContextWrapper(self._get_child(context, with_lock))
return AsyncContextWrapper(self._create_child(context, with_lock))

async def _get_parent(self, dependency_type: Type[T]) -> T:
return await self.parent_container.get(dependency_type)

async def _get_self(
self,
dep_provider: Factory,
) -> T:
async def _get_from_self(self, factory: Factory) -> T:
sub_dependencies = [
await self._get_unlocked(dependency)
for dependency in dep_provider.dependencies
for dependency in factory.dependencies
]
if dep_provider.type is FactoryType.GENERATOR:
generator = dep_provider.source(*sub_dependencies)
self.exits.append(Exit(dep_provider.type, generator))
if factory.type is FactoryType.GENERATOR:
generator = factory.source(*sub_dependencies)
self.exits.append(Exit(factory.type, generator))
return next(generator)
elif dep_provider.type is FactoryType.ASYNC_GENERATOR:
generator = dep_provider.source(*sub_dependencies)
self.exits.append(Exit(dep_provider.type, generator))
elif factory.type is FactoryType.ASYNC_GENERATOR:
generator = factory.source(*sub_dependencies)
self.exits.append(Exit(factory.type, generator))
return await anext(generator)
elif dep_provider.type is FactoryType.ASYNC_FACTORY:
return await dep_provider.source(*sub_dependencies)
elif dep_provider.type is FactoryType.FACTORY:
return dep_provider.source(*sub_dependencies)
elif dep_provider.type is FactoryType.VALUE:
return dep_provider.source
elif factory.type is FactoryType.ASYNC_FACTORY:
return await factory.source(*sub_dependencies)
elif factory.type is FactoryType.FACTORY:
return factory.source(*sub_dependencies)
elif factory.type is FactoryType.VALUE:
return factory.source
else:
raise ValueError(f"Unsupported type {dep_provider.type}")
raise ValueError(f"Unsupported type {factory.type}")

async def get(self, dependency_type: Type[T]) -> T:
lock = self.lock
Expand All @@ -113,7 +107,7 @@ async def _get_unlocked(self, dependency_type: Type[T]) -> T:
if not self.parent_container:
raise ValueError(f"No provider found for {dependency_type!r}")
return await self.parent_container.get(dependency_type)
solved = await self._get_self(provider)
solved = await self._get_from_self(provider)
self.context[dependency_type] = solved
return solved

Expand Down
32 changes: 13 additions & 19 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
self.lock = None
self.exits: List[Exit] = []

def _get_child(
def _create_child(
self,
context: Optional[dict],
with_lock: bool,
Expand All @@ -68,29 +68,23 @@ def __call__(
"""
if not self.child_registries:
raise ValueError("No child scopes found")
return ContextWrapper(self._get_child(context, with_lock))
return ContextWrapper(self._create_child(context, with_lock))

def _get_parent(self, dependency_type: Type[T]) -> T:
return self.parent_container.get(dependency_type)

def _get_self(
self,
dep_provider: Factory,
) -> T:
def _get_from_self(self, factory: Factory) -> T:
sub_dependencies = [
self._get_unlocked(dependency)
for dependency in dep_provider.dependencies
for dependency in factory.dependencies
]
if dep_provider.type is FactoryType.GENERATOR:
generator = dep_provider.source(*sub_dependencies)
self.exits.append(Exit(dep_provider.type, generator))
if factory.type is FactoryType.GENERATOR:
generator = factory.source(*sub_dependencies)
self.exits.append(Exit(factory.type, generator))
return next(generator)
elif dep_provider.type is FactoryType.FACTORY:
return dep_provider.source(*sub_dependencies)
elif dep_provider.type is FactoryType.VALUE:
return dep_provider.source
elif factory.type is FactoryType.FACTORY:
return factory.source(*sub_dependencies)
elif factory.type is FactoryType.VALUE:
return factory.source
else:
raise ValueError(f"Unsupported type {dep_provider.type}")
raise ValueError(f"Unsupported type {factory.type}")

def get(self, dependency_type: Type[T]) -> T:
lock = self.lock
Expand All @@ -107,7 +101,7 @@ def _get_unlocked(self, dependency_type: Type[T]) -> T:
if not self.parent_container:
raise ValueError(f"No provider found for {dependency_type!r}")
return self.parent_container.get(dependency_type)
solved = self._get_self(provider)
solved = self._get_from_self(provider)
self.context[dependency_type] = solved
return solved

Expand Down
37 changes: 30 additions & 7 deletions src/dishka/provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import List
import inspect
from typing import Any, List

from .dependency_source import DependencySource
from .dependency_source import Alias, Decorator, DependencySource, Factory


def is_dependency_source(attribute: Any) -> bool:
return isinstance(attribute, DependencySource)


class Provider:
Expand All @@ -18,8 +23,26 @@ class Provider:
"""

def __init__(self):
self.dependency_sources: List[DependencySource] = [
getattr(self, name)
for name, attr in vars(type(self)).items()
if isinstance(attr, DependencySource)
]
self.factories: List[Factory] = []
self.aliases: List[Alias] = []
self.decorators: List[Decorator] = []
self._init_dependency_sources()

def _init_dependency_sources(self) -> None:
processed_types = {}

source: DependencySource
for name, source in inspect.getmembers(self, is_dependency_source):
if source.provides in processed_types:
raise ValueError(
f"Type {source.provides} is registered multiple times "
f"in the same {Provider} by attributes "
f"{processed_types[source.provides]!r} and {name!r}",
)
if isinstance(source, Alias):
self.aliases.append(source)
if isinstance(source, Factory):
self.factories.append(source)
if isinstance(source, Decorator):
self.decorators.append(source)
processed_types[source.provides] = name
64 changes: 36 additions & 28 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Any, List, NewType, Type

from .dependency_source import Alias, Decorator, Factory
from .dependency_source import Factory
from .provider import Provider
from .scope import BaseScope

Expand All @@ -23,40 +23,48 @@ def get_provider(self, dependency: Any) -> Factory:
def make_registries(
*providers: Provider, scopes: Type[BaseScope],
) -> List[Registry]:
dep_scopes = {}
dep_scopes: dict[Type, BaseScope] = {}
alias_sources = {}
for provider in providers:
for source in provider.dependency_sources:
if hasattr(source, "scope"):
dep_scopes[source.provides] = source.scope
for source in provider.factories:
dep_scopes[source.provides] = source.scope
for source in provider.aliases:
alias_sources[source.provides] = source.source

registries = {scope: Registry(scope) for scope in scopes}
decorator_depth: dict[Type, int] = defaultdict(int)

for provider in providers:
for source in provider.dependency_sources:
for source in provider.factories:
scope = source.scope
registries[scope].add_provider(source)
for source in provider.aliases:
alias_source = source.source
visited_types = [alias_source]
while alias_source not in dep_scopes:
alias_source = alias_sources[alias_source]
if alias_source in visited_types:
raise ValueError(f"Cycle aliases detected {visited_types}")
visited_types.append(alias_source)
scope = dep_scopes[alias_source]
dep_scopes[source.provides] = scope
source = source.as_factory(scope)
registries[scope].add_provider(source)
for source in provider.decorators:
provides = source.provides
if isinstance(source, Factory):
scope = source.scope
elif isinstance(source, Alias):
scope = dep_scopes[source.source]
dep_scopes[provides] = scope
source = source.as_factory(scope)
elif isinstance(source, Decorator):
scope = dep_scopes[provides]
registry = registries[scope]
undecorated_type = NewType(
f"{provides.__name__}@{decorator_depth[provides]}",
source.provides,
)
decorator_depth[provides] += 1
old_provider = registry.get_provider(provides)
old_provider.provides = undecorated_type
registry.add_provider(old_provider)
source = source.as_factory(
scope, undecorated_type,
)
else:
raise ValueError("Unknown dependency source type")
scope = dep_scopes[provides]
registry = registries[scope]
undecorated_type = NewType(
f"{provides.__name__}@{decorator_depth[provides]}",
source.provides,
)
decorator_depth[provides] += 1
old_provider = registry.get_provider(provides)
old_provider.provides = undecorated_type
registry.add_provider(old_provider)
source = source.as_factory(
scope, undecorated_type,
)
registries[scope].add_provider(source)

return list(registries.values())
33 changes: 33 additions & 0 deletions tests/container/test_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from dishka import Provider, Scope, alias, make_container, provide


class AliasProvider(Provider):
@provide(scope=Scope.APP)
def provide_int(self) -> int:
return 42

aliased_complex = alias(source=float, provides=complex)
aliased_float = alias(source=int, provides=float)


def test_alias():
with make_container(AliasProvider()) as container:
assert container.get(float) == container.get(int)


def test_alias_to_alias():
with make_container(AliasProvider()) as container:
assert container.get(complex) == container.get(int)


class CycleProvider(Provider):
a = alias(source=int, provides=bool)
b = alias(source=bool, provides=float)
c = alias(source=float, provides=int)


def test_cycle():
with pytest.raises(ValueError):
make_container(CycleProvider())
27 changes: 23 additions & 4 deletions tests/container/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from dishka import Provider, Scope, alias, decorate, make_container, provide


Expand All @@ -21,9 +23,11 @@ def __init__(self, a: A):
def test_simple():
class MyProvider(Provider):
a = provide(A, scope=Scope.APP)

class DProvider(Provider):
ad = decorate(ADecorator, provides=A)

with make_container(MyProvider()) as container:
with make_container(MyProvider(), DProvider()) as container:
a = container.get(A)
assert isinstance(a, ADecorator)
assert isinstance(a.a, A)
Expand All @@ -35,11 +39,12 @@ class MyProvider(Provider):
a1 = alias(source=A2, provides=A1)
a = alias(source=A1, provides=A)

class DProvider(Provider):
@decorate
def decorated(self, a: A1) -> A1:
return ADecorator(a)

with make_container(MyProvider()) as container:
with make_container(MyProvider(), DProvider()) as container:
a1 = container.get(A1)
assert isinstance(a1, ADecorator)
assert isinstance(a1.a, A2)
Expand All @@ -52,13 +57,27 @@ def decorated(self, a: A1) -> A1:
assert a is a1


def test_double():
def test_double_error():
class MyProvider(Provider):
a = provide(A, scope=Scope.APP)
ad = decorate(ADecorator, provides=A)
ad2 = decorate(ADecorator, provides=A)

with make_container(MyProvider()) as container:
with pytest.raises(ValueError):
MyProvider()


def test_double_ok():
class MyProvider(Provider):
a = provide(A, scope=Scope.APP)

class DProvider(Provider):
ad = decorate(ADecorator, provides=A)

class D2Provider(Provider):
ad2 = decorate(ADecorator, provides=A)

with make_container(MyProvider(), DProvider(), D2Provider()) as container:
a = container.get(A)
assert isinstance(a, ADecorator)
assert isinstance(a.a, ADecorator)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
def test_provider_init():
class MyProvider(Provider):
a = alias(source=int, provides=bool)
b = provide(lambda: False, scope=Scope.APP, provides=bool)

@provide(scope=Scope.REQUEST)
def foo(self, x: bool) -> str:
return f"{x}"

provider = MyProvider()
assert len(provider.dependency_sources) == 3
assert len(provider.factories) == 1
assert len(provider.aliases) == 1


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3a89466

Please sign in to comment.