Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 80 additions & 47 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,25 @@
import threading
import types
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
cast,
Dict,
Generator,
Generic,
get_args,
Iterable,
List,
Optional,
overload,
Set,
Tuple,
Type,
TypeVar,
TYPE_CHECKING,
Union,
cast,
get_args,
overload,
)

try:
Expand All @@ -52,13 +53,13 @@
# canonical. Since this typing_extensions import is only for mypy it'll work even without
# typing_extensions actually installed so all's good.
if TYPE_CHECKING:
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
else:
# Ignoring errors here as typing_extensions stub doesn't know about those things yet
try:
from typing import _AnnotatedAlias, Annotated, get_type_hints
from typing import Annotated, _AnnotatedAlias, get_type_hints
except ImportError:
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints


__author__ = 'Alec Thomas <[email protected]>'
Expand Down Expand Up @@ -340,39 +341,60 @@ def __repr__(self) -> str:


@private
class ListOfProviders(Provider, Generic[T]):
class MultiBinder(Provider, Generic[T]):
"""Provide a list of instances via other Providers."""

_providers: List[Provider[T]]
_multi_bindings: List['Binding']

def __init__(self) -> None:
self._providers = []
def __init__(self, parent: 'Binder') -> None:
self._multi_bindings = []
self._binder = Binder(parent.injector, auto_bind=False, parent=parent)

def append(self, provider: Provider[T]) -> None:
self._providers.append(provider)
def append(self, provider: Provider[T], scope: Type['Scope']) -> None:
# HACK: generate a pseudo-type for this element in the list.
# This is needed for scopes to work properly. Some, like the Singleton scope,
# key instances by type, so we need one that is unique to this binding.
pseudo_type = type(f"multibind-type-{id(provider)}", (provider.__class__,), {})
Comment on lines +354 to +357
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate a bit on in what case this is needed? Is it required for the case when multiple different types are bound with the singleton scope in the same multibound type? Just like PluginA and PluginC in the case below:

def test_multibind_dict_scopes_applies_to_the_bound_items() -> None:
    def configure(binder: Binder) -> None:
        binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton)
        binder.multibind(Dict[str, Plugin], to={'b': PluginB})
        binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I quickly tried to reproduce the problem without the pseudo-type, by instead passing the resolved type, which is Plugin for the example above. I got a weird problem in that test case that I haven't yet resolved, but for multibound lists every case I can come up with seems to work as expected. This makes me suspect it's not really needed.

self._multi_bindings.append(Binding(pseudo_type, provider, scope))

def get_scoped_providers(self, injector: 'Injector') -> Generator[Provider[T], None, None]:
for binding in self._multi_bindings:
if (
isinstance(binding.provider, ClassProvider)
and binding.scope is NoScope
and self._binder.parent
and self._binder.parent.has_explicit_binding_for(binding.provider._cls)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm planning on removing the has_explicit_binding_for() condition, since an explicit binding (that should be respected) may exist on parent.parent.parent as well. That change also makes sure that classes decorated with @singleton are instantiated as far up in the hierarchy as possible, which makes sure the instance is shared among all child injectors.

Do you have any argument against doing this? All tests still pass.

):
parent_binding, _ = self._binder.parent.get_binding(binding.provider._cls)
Copy link
Collaborator

@davidparsson davidparsson Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should respect the parent provider here as well, for the same reasons as for scopes. I'll fix this.

scope_binding, _ = self._binder.parent.get_binding(parent_binding.scope)
else:
scope_binding, _ = self._binder.get_binding(binding.scope)
scope_instance: Scope = scope_binding.provider.get(injector)
provider_instance = scope_instance.get(binding.interface, binding.provider)
yield provider_instance

def __repr__(self) -> str:
return '%s(%r)' % (type(self).__name__, self._providers)
return '%s(%r)' % (type(self).__name__, self._multi_bindings)


class MultiBindProvider(ListOfProviders[List[T]]):
class MultiBindProvider(MultiBinder[List[T]]):
"""Used by :meth:`Binder.multibind` to flatten results of providers that
return sequences."""

def get(self, injector: 'Injector') -> List[T]:
result: List[T] = []
for provider in self._providers:
for provider in self.get_scoped_providers(injector):
instances: List[T] = _ensure_iterable(provider.get(injector))
result.extend(instances)
return result


class MapBindProvider(ListOfProviders[Dict[str, T]]):
class MapBindProvider(MultiBinder[Dict[str, T]]):
"""A provider for map bindings."""

def get(self, injector: 'Injector') -> Dict[str, T]:
map: Dict[str, T] = {}
for provider in self._providers:
for provider in self.get_scoped_providers(injector):
map.update(provider.get(injector))
return map

Expand All @@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
return {self._key: self._provider.get(injector)}


_BindingBase = namedtuple('_BindingBase', 'interface provider scope')
@dataclass
class _BindingBase:
interface: type
provider: Provider
scope: Type['Scope']


@private
Expand Down Expand Up @@ -531,44 +557,51 @@ def multibind(

:param scope: Optional Scope in which to bind.
"""
if interface not in self._bindings:
provider: ListOfProviders
if (
isinstance(interface, dict)
or isinstance(interface, type)
and issubclass(interface, dict)
or _get_origin(_punch_through_alias(interface)) is dict
):
provider = MapBindProvider()
else:
provider = MultiBindProvider()
binding = self.create_binding(interface, provider, scope)
self._bindings[interface] = binding
else:
binding = self._bindings[interface]
provider = binding.provider
assert isinstance(provider, ListOfProviders)

if isinstance(provider, MultiBindProvider) and isinstance(to, list):
multi_binder = self._get_multi_binder(interface)
if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list):
try:
element_type = get_args(_punch_through_alias(interface))[0]
except IndexError:
raise InvalidInterface(
f"Use typing.List[T] or list[T] to specify the element type of the list"
)
for element in to:
provider.append(self.provider_for(element_type, element))
elif isinstance(provider, MapBindProvider) and isinstance(to, dict):
element_binding = self.create_binding(element_type, element, scope)
multi_binder.append(element_binding.provider, element_binding.scope)
elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict):
try:
value_type = get_args(_punch_through_alias(interface))[1]
except IndexError:
raise InvalidInterface(
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
)
for key, value in to.items():
provider.append(KeyValueProvider(key, self.provider_for(value_type, value)))
element_binding = self.create_binding(value_type, value, scope)
multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
else:
provider.append(self.provider_for(interface, to))
element_binding = self.create_binding(interface, to, scope)
multi_binder.append(element_binding.provider, element_binding.scope)

def _get_multi_binder(self, interface: type) -> MultiBinder:
multi_binder: MultiBinder
if interface not in self._bindings:
if (
isinstance(interface, dict)
or isinstance(interface, type)
and issubclass(interface, dict)
or _get_origin(_punch_through_alias(interface)) is dict
):
multi_binder = MapBindProvider(self)
else:
multi_binder = MultiBindProvider(self)
binding = self.create_binding(interface, multi_binder)
self._bindings[interface] = binding
else:
binding = self._bindings[interface]
assert isinstance(binding.provider, MultiBinder)
multi_binder = binding.provider

return multi_binder

def install(self, module: _InstallableModuleType) -> None:
"""Install a module into this binder.
Expand Down Expand Up @@ -611,10 +644,10 @@ def create_binding(
self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None
) -> Binding:
provider = self.provider_for(interface, to)
scope = scope or getattr(to or interface, '__scope__', NoScope)
scope = scope or getattr(to or interface, '__scope__', None)
if isinstance(scope, ScopeDecorator):
scope = scope.scope
return Binding(interface, provider, scope)
return Binding(interface, provider, scope or NoScope)

def provider_for(self, interface: Any, to: Any = None) -> Provider:
base_type = _punch_through_alias(interface)
Expand Down Expand Up @@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
# The special interface is added here so that requesting a special
# interface with auto_bind disabled works
if self._auto_bind or self._is_special_interface(interface):
binding = ImplicitBinding(*self.create_binding(interface))
binding = ImplicitBinding(**self.create_binding(interface).__dict__)
self._bindings[interface] = binding
return binding, self

Expand Down Expand Up @@ -817,7 +850,7 @@ def __repr__(self) -> str:
class NoScope(Scope):
"""An unscoped provider."""

def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]:
def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]:
return provider


Expand Down
93 changes: 76 additions & 17 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

"""Functional tests for the "Injector" dependency injection framework."""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, NewType, Optional, Union
import abc
import sys
import threading
import traceback
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, NewType, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
Expand All @@ -29,32 +29,32 @@
import pytest

from injector import (
AssistedBuilder,
Binder,
CallError,
CircularDependency,
ClassAssistedBuilder,
ClassProvider,
Error,
Inject,
Injector,
InstanceProvider,
InvalidInterface,
Module,
NoInject,
ProviderOf,
Scope,
InstanceProvider,
ClassProvider,
ScopeDecorator,
SingletonScope,
UnknownArgument,
UnsatisfiedRequirement,
get_bindings,
inject,
multiprovider,
noninjectable,
provider,
singleton,
threadlocal,
UnsatisfiedRequirement,
CircularDependency,
Module,
SingletonScope,
ScopeDecorator,
AssistedBuilder,
provider,
ProviderOf,
ClassAssistedBuilder,
Error,
UnknownArgument,
InvalidInterface,
)


Expand Down Expand Up @@ -723,6 +723,65 @@ def configure_dict(binder: Binder):
Injector([configure_dict])


def test_multibind_types_respect_the_bound_type_scope() -> None:
def configure(binder: Binder) -> None:
binder.bind(PluginA, to=PluginA, scope=singleton)
binder.multibind(List[Plugin], to=PluginA)

injector = Injector([configure])
first_list = injector.get(List[Plugin])
second_list = injector.get(List[Plugin])
child_injector = injector.create_child_injector()
third_list = child_injector.get(List[Plugin])

assert first_list[0] is second_list[0]
assert third_list[0] is second_list[0]


def test_multibind_list_scopes_applies_to_the_bound_items() -> None:
def configure(binder: Binder) -> None:
binder.multibind(List[Plugin], to=PluginA, scope=singleton)
binder.multibind(List[Plugin], to=PluginB)
binder.multibind(List[Plugin], to=[PluginC], scope=singleton)

injector = Injector([configure])
first_list = injector.get(List[Plugin])
second_list = injector.get(List[Plugin])

assert first_list is not second_list
assert first_list[0] is second_list[0]
assert first_list[1] is not second_list[1]
assert first_list[2] is second_list[2]


def test_multibind_dict_scopes_applies_to_the_bound_items() -> None:
def configure(binder: Binder) -> None:
binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton)
binder.multibind(Dict[str, Plugin], to={'b': PluginB})
binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton)

injector = Injector([configure])
first_dict = injector.get(Dict[str, Plugin])
second_dict = injector.get(Dict[str, Plugin])

assert first_dict is not second_dict
assert first_dict['a'] is second_dict['a']
assert first_dict['b'] is not second_dict['b']
assert first_dict['c'] is second_dict['c']


def test_multibind_scopes_does_not_apply_to_the_type_globally() -> None:
def configure(binder: Binder) -> None:
binder.multibind(List[Plugin], to=PluginA, scope=singleton)

injector = Injector([configure])
plugins = injector.get(List[Plugin])

assert plugins[0] is not injector.get(PluginA)
assert plugins[0] is not injector.get(Plugin)
assert injector.get(PluginA) is not injector.get(PluginA)


def test_regular_bind_and_provider_dont_work_with_multibind():
# We only want multibind and multiprovider to work to avoid confusion

Expand Down