Skip to content

Commit

Permalink
Allow annotated in classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Filip Nešťák committed Nov 3, 2024
1 parent 0e9905e commit 7e7c431
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
6 changes: 5 additions & 1 deletion injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,9 @@ def _is_new_union_type(instance: Any) -> bool:
new_union_type = getattr(types, 'UnionType', None)
return new_union_type is not None and isinstance(instance, new_union_type)

def _is_package_annotation(annotation: Any) -> bool:
return _is_specialization(annotation, Annotated) and (_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__)

spec = inspect.getfullargspec(callable)

try:
Expand Down Expand Up @@ -1238,7 +1241,8 @@ def _is_new_union_type(instance: Any) -> bool:
bindings.pop(spec.varkw, None)

for k, v in list(bindings.items()):
if _is_specialization(v, Annotated):
# extract metadata only from Inject and NonInject
if _is_package_annotation(v):
v, metadata = v.__origin__, v.__metadata__
bindings[k] = v
else:
Expand Down
34 changes: 34 additions & 0 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,3 +1754,37 @@ def configure(binder):
injector = Injector([configure])
assert injector.get(foo) == 123
assert injector.get(bar) == 456


def test_annotated_integration_with_annotated():
UserID = Annotated[int, 'user_id']

@inject
class TestClass:
def __init__(self, user_id: UserID):
self.user_id = user_id

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123


def test_newtype_integration_with_annotated():
UserID = NewType('UserID', int)

@inject
class TestClass:
def __init__(self, user_id: UserID):
self.user_id = user_id

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123

0 comments on commit 7e7c431

Please sign in to comment.