From 18d538feea921c625a536cfe60ce8379d5bd71a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Ne=C5=A1=C5=A5=C3=A1k?= Date: Sun, 3 Nov 2024 22:37:38 +0100 Subject: [PATCH] Allow annotated in classes --- injector/__init__.py | 6 +++++- injector_test.py | 50 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/injector/__init__.py b/injector/__init__.py index 5c9cf9c..d9bdd5f 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -1208,6 +1208,9 @@ def _is_new_union_type(instance: Any) -> bool: new_union_type = getattr(types, 'UnionType', None) return new_union_type is not None and isinstance(instance, new_union_type) + def _is_package_annotation(annotation: Any) -> bool: + return _is_specialization(annotation, Annotated) and (_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__) + spec = inspect.getfullargspec(callable) try: @@ -1238,7 +1241,8 @@ def _is_new_union_type(instance: Any) -> bool: bindings.pop(spec.varkw, None) for k, v in list(bindings.items()): - if _is_specialization(v, Annotated): + # extract metadata only from Inject and NonInject + if _is_package_annotation(v): v, metadata = v.__origin__, v.__metadata__ bindings[k] = v else: diff --git a/injector_test.py b/injector_test.py index 3d98254..c50f810 100644 --- a/injector_test.py +++ b/injector_test.py @@ -11,6 +11,7 @@ """Functional tests for the "Injector" dependency injection framework.""" from contextlib import contextmanager +from dataclasses import dataclass from typing import Any, NewType, Optional, Union import abc import sys @@ -1754,3 +1755,52 @@ def configure(binder): injector = Injector([configure]) assert injector.get(foo) == 123 assert injector.get(bar) == 456 + + +def test_annotated_integration_with_annotated(): + UserID = Annotated[int, 'user_id'] + + @inject + class TestClass: + def __init__(self, user_id: UserID): + self.user_id = user_id + + def configure(binder): + binder.bind(UserID, to=123) + + injector = Injector([configure]) + + test_class = injector.get(TestClass) + assert test_class.user_id == 123 + + +def test_newtype_integration_with_annotated(): + UserID = NewType('UserID', int) + + @inject + class TestClass: + def __init__(self, user_id: UserID): + self.user_id = user_id + + def configure(binder): + binder.bind(UserID, to=123) + + injector = Injector([configure]) + + test_class = injector.get(TestClass) + assert test_class.user_id == 123 + +def test_dataclass_annotated_parameter(): + Foo = Annotated[int, object()] + + def configure(binder): + binder.bind(Foo, to=123) + + @inject + @dataclass + class MyClass: + foo: Foo + + injector = Injector([configure]) + instance = injector.get(MyClass) + assert instance.foo == 123 \ No newline at end of file