Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,16 @@ def _is_injection_annotation(annotation: Any) -> bool:
_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__
)

def _recreate_annotated_origin(annotated_type: Any) -> Any:
# Creates `Annotated[type, annotation]` from `Inject[Annotated[type, annotation]]`,
# to support the injection of annotated types with the `Inject[]` annotation.
origin = annotated_type.__origin__
for metadata in annotated_type.__metadata__: # pragma: no branch
if metadata in (_inject_marker, _noinject_marker):
break
origin = Annotated[origin, metadata]
return origin

spec = inspect.getfullargspec(callable)

try:
Expand Down Expand Up @@ -1245,7 +1255,7 @@ def _is_injection_annotation(annotation: Any) -> bool:
for k, v in list(bindings.items()):
# extract metadata only from Inject and NonInject
if _is_injection_annotation(v):
v, metadata = v.__origin__, v.__metadata__
v, metadata = _recreate_annotated_origin(v), v.__metadata__
bindings[k] = v
else:
metadata = tuple()
Expand Down
138 changes: 137 additions & 1 deletion injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,67 @@ def function(a: int) -> 'InvalidForwardReference':
assert get_bindings(function) == {'a': int}


def test_gets_bindings_for_annotated_type_with_inject_decorator() -> None:
UserID = Annotated[int, 'user_id']

@inject
def function(a: UserID, b: str) -> None:
pass

assert get_bindings(function) == {'a': UserID, 'b': str}


def test_gets_bindings_of_annotated_type_with_inject_annotation() -> None:
UserID = Annotated[int, 'user_id']

def function(a: Inject[UserID], b: Inject[str]) -> None:
pass

assert get_bindings(function) == {'a': UserID, 'b': str}


def test_gets_bindings_of_new_type_with_inject_annotation() -> None:
Name = NewType('Name', str)

@inject
def function(a: Name, b: str) -> None:
pass

assert get_bindings(function) == {'a': Name, 'b': str}


def test_gets_bindings_of_inject_annotation_with_new_type() -> None:
def function(a: Inject[Name], b: str) -> None:
pass

assert get_bindings(function) == {'a': Name}


def test_get_bindings_of_nested_noinject_inject_annotation() -> None:
# This is not how this is intended to be used
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
pass

assert get_bindings(function) == {}


def test_get_bindings_of_nested_noinject_inject_annotation_and_inject_decorator() -> None:
# This is not how this is intended to be used
@inject
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
pass

assert get_bindings(function) == {}


def test_get_bindings_of_nested_inject_annotations() -> None:
# This is not how this is intended to be used
def function(a: Inject[Inject[int]]) -> None:
pass

assert get_bindings(function) == {'a': int}


# Tests https://github.com/alecthomas/injector/issues/202
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
def test_get_bindings_for_pep_604():
Expand Down Expand Up @@ -1785,21 +1846,80 @@ def configure(binder):

def test_annotated_integration_with_annotated():
UserID = Annotated[int, 'user_id']
UserAge = Annotated[int, 'user_age']

@inject
class TestClass:
def __init__(self, user_id: UserID):
def __init__(self, user_id: UserID, user_age: UserAge):
self.user_id = user_id
self.user_age = user_age

def configure(binder):
binder.bind(UserID, to=123)
binder.bind(UserAge, to=32)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123
assert test_class.user_age == 32


def test_inject_annotation_with_annotated_type():
UserID = Annotated[int, 'user_id']
UserAge = Annotated[int, 'user_age']

class TestClass:
def __init__(self, user_id: Inject[UserID], user_age: Inject[UserAge]):
self.user_id = user_id
self.user_age = user_age

def configure(binder):
binder.bind(UserID, to=123)
binder.bind(UserAge, to=32)
binder.bind(int, to=456)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123
assert test_class.user_age == 32


def test_inject_annotation_with_nested_annotated_type():
UserID = Annotated[int, 'user_id']
SpecialUserID = Annotated[UserID, 'special_user_id']

class TestClass:
def __init__(self, user_id: Inject[SpecialUserID]):
self.user_id = user_id

def configure(binder):
binder.bind(SpecialUserID, to=123)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123


def test_noinject_annotation_with_annotated_type():
UserID = Annotated[int, 'user_id']

@inject
class TestClass:
def __init__(self, user_id: NoInject[UserID] = None):
self.user_id = user_id

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id is None


def test_newtype_integration_with_annotated():
UserID = NewType('UserID', int)

Expand All @@ -1817,6 +1937,22 @@ def configure(binder):
assert test_class.user_id == 123


def test_newtype_with_injection_annotation():
UserID = NewType('UserID', int)

class TestClass:
def __init__(self, user_id: Inject[UserID]):
self.user_id = user_id

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])

test_class = injector.get(TestClass)
assert test_class.user_id == 123


def test_dataclass_annotated_parameter():
Foo = Annotated[int, object()]

Expand Down