Skip to content

Commit 0256913

Browse files
authored
Merge pull request #284 from eirikurt/fix-multibind-scopes
fix: Multibind scopes
2 parents f3b8a49 + 9faa32a commit 0256913

File tree

2 files changed

+156
-64
lines changed

2 files changed

+156
-64
lines changed

injector/__init__.py

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,25 @@
2222
import threading
2323
import types
2424
from abc import ABCMeta, abstractmethod
25-
from collections import namedtuple
25+
from dataclasses import dataclass
2626
from typing import (
27+
TYPE_CHECKING,
2728
Any,
2829
Callable,
29-
cast,
3030
Dict,
31+
Generator,
3132
Generic,
32-
get_args,
3333
Iterable,
3434
List,
3535
Optional,
36-
overload,
3736
Set,
3837
Tuple,
3938
Type,
4039
TypeVar,
41-
TYPE_CHECKING,
4240
Union,
41+
cast,
42+
get_args,
43+
overload,
4344
)
4445

4546
try:
@@ -52,13 +53,13 @@
5253
# canonical. Since this typing_extensions import is only for mypy it'll work even without
5354
# typing_extensions actually installed so all's good.
5455
if TYPE_CHECKING:
55-
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
56+
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
5657
else:
5758
# Ignoring errors here as typing_extensions stub doesn't know about those things yet
5859
try:
59-
from typing import _AnnotatedAlias, Annotated, get_type_hints
60+
from typing import Annotated, _AnnotatedAlias, get_type_hints
6061
except ImportError:
61-
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
62+
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
6263

6364

6465
__author__ = 'Alec Thomas <[email protected]>'
@@ -340,39 +341,60 @@ def __repr__(self) -> str:
340341

341342

342343
@private
343-
class ListOfProviders(Provider, Generic[T]):
344+
class MultiBinder(Provider, Generic[T]):
344345
"""Provide a list of instances via other Providers."""
345346

346-
_providers: List[Provider[T]]
347+
_multi_bindings: List['Binding']
347348

348-
def __init__(self) -> None:
349-
self._providers = []
349+
def __init__(self, parent: 'Binder') -> None:
350+
self._multi_bindings = []
351+
self._binder = Binder(parent.injector, auto_bind=False, parent=parent)
350352

351-
def append(self, provider: Provider[T]) -> None:
352-
self._providers.append(provider)
353+
def append(self, provider: Provider[T], scope: Type['Scope']) -> None:
354+
# HACK: generate a pseudo-type for this element in the list.
355+
# This is needed for scopes to work properly. Some, like the Singleton scope,
356+
# key instances by type, so we need one that is unique to this binding.
357+
pseudo_type = type(f"multibind-type-{id(provider)}", (provider.__class__,), {})
358+
self._multi_bindings.append(Binding(pseudo_type, provider, scope))
359+
360+
def get_scoped_providers(self, injector: 'Injector') -> Generator[Provider[T], None, None]:
361+
for binding in self._multi_bindings:
362+
if (
363+
isinstance(binding.provider, ClassProvider)
364+
and binding.scope is NoScope
365+
and self._binder.parent
366+
and self._binder.parent.has_explicit_binding_for(binding.provider._cls)
367+
):
368+
parent_binding, _ = self._binder.parent.get_binding(binding.provider._cls)
369+
scope_binding, _ = self._binder.parent.get_binding(parent_binding.scope)
370+
else:
371+
scope_binding, _ = self._binder.get_binding(binding.scope)
372+
scope_instance: Scope = scope_binding.provider.get(injector)
373+
provider_instance = scope_instance.get(binding.interface, binding.provider)
374+
yield provider_instance
353375

354376
def __repr__(self) -> str:
355-
return '%s(%r)' % (type(self).__name__, self._providers)
377+
return '%s(%r)' % (type(self).__name__, self._multi_bindings)
356378

357379

358-
class MultiBindProvider(ListOfProviders[List[T]]):
380+
class MultiBindProvider(MultiBinder[List[T]]):
359381
"""Used by :meth:`Binder.multibind` to flatten results of providers that
360382
return sequences."""
361383

362384
def get(self, injector: 'Injector') -> List[T]:
363385
result: List[T] = []
364-
for provider in self._providers:
386+
for provider in self.get_scoped_providers(injector):
365387
instances: List[T] = _ensure_iterable(provider.get(injector))
366388
result.extend(instances)
367389
return result
368390

369391

370-
class MapBindProvider(ListOfProviders[Dict[str, T]]):
392+
class MapBindProvider(MultiBinder[Dict[str, T]]):
371393
"""A provider for map bindings."""
372394

373395
def get(self, injector: 'Injector') -> Dict[str, T]:
374396
map: Dict[str, T] = {}
375-
for provider in self._providers:
397+
for provider in self.get_scoped_providers(injector):
376398
map.update(provider.get(injector))
377399
return map
378400

@@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
387409
return {self._key: self._provider.get(injector)}
388410

389411

390-
_BindingBase = namedtuple('_BindingBase', 'interface provider scope')
412+
@dataclass
413+
class _BindingBase:
414+
interface: type
415+
provider: Provider
416+
scope: Type['Scope']
391417

392418

393419
@private
@@ -531,44 +557,51 @@ def multibind(
531557
532558
:param scope: Optional Scope in which to bind.
533559
"""
534-
if interface not in self._bindings:
535-
provider: ListOfProviders
536-
if (
537-
isinstance(interface, dict)
538-
or isinstance(interface, type)
539-
and issubclass(interface, dict)
540-
or _get_origin(_punch_through_alias(interface)) is dict
541-
):
542-
provider = MapBindProvider()
543-
else:
544-
provider = MultiBindProvider()
545-
binding = self.create_binding(interface, provider, scope)
546-
self._bindings[interface] = binding
547-
else:
548-
binding = self._bindings[interface]
549-
provider = binding.provider
550-
assert isinstance(provider, ListOfProviders)
551-
552-
if isinstance(provider, MultiBindProvider) and isinstance(to, list):
560+
multi_binder = self._get_multi_binder(interface)
561+
if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list):
553562
try:
554563
element_type = get_args(_punch_through_alias(interface))[0]
555564
except IndexError:
556565
raise InvalidInterface(
557566
f"Use typing.List[T] or list[T] to specify the element type of the list"
558567
)
559568
for element in to:
560-
provider.append(self.provider_for(element_type, element))
561-
elif isinstance(provider, MapBindProvider) and isinstance(to, dict):
569+
element_binding = self.create_binding(element_type, element, scope)
570+
multi_binder.append(element_binding.provider, element_binding.scope)
571+
elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict):
562572
try:
563573
value_type = get_args(_punch_through_alias(interface))[1]
564574
except IndexError:
565575
raise InvalidInterface(
566576
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
567577
)
568578
for key, value in to.items():
569-
provider.append(KeyValueProvider(key, self.provider_for(value_type, value)))
579+
element_binding = self.create_binding(value_type, value, scope)
580+
multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
570581
else:
571-
provider.append(self.provider_for(interface, to))
582+
element_binding = self.create_binding(interface, to, scope)
583+
multi_binder.append(element_binding.provider, element_binding.scope)
584+
585+
def _get_multi_binder(self, interface: type) -> MultiBinder:
586+
multi_binder: MultiBinder
587+
if interface not in self._bindings:
588+
if (
589+
isinstance(interface, dict)
590+
or isinstance(interface, type)
591+
and issubclass(interface, dict)
592+
or _get_origin(_punch_through_alias(interface)) is dict
593+
):
594+
multi_binder = MapBindProvider(self)
595+
else:
596+
multi_binder = MultiBindProvider(self)
597+
binding = self.create_binding(interface, multi_binder)
598+
self._bindings[interface] = binding
599+
else:
600+
binding = self._bindings[interface]
601+
assert isinstance(binding.provider, MultiBinder)
602+
multi_binder = binding.provider
603+
604+
return multi_binder
572605

573606
def install(self, module: _InstallableModuleType) -> None:
574607
"""Install a module into this binder.
@@ -611,10 +644,10 @@ def create_binding(
611644
self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None
612645
) -> Binding:
613646
provider = self.provider_for(interface, to)
614-
scope = scope or getattr(to or interface, '__scope__', NoScope)
647+
scope = scope or getattr(to or interface, '__scope__', None)
615648
if isinstance(scope, ScopeDecorator):
616649
scope = scope.scope
617-
return Binding(interface, provider, scope)
650+
return Binding(interface, provider, scope or NoScope)
618651

619652
def provider_for(self, interface: Any, to: Any = None) -> Provider:
620653
base_type = _punch_through_alias(interface)
@@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
696729
# The special interface is added here so that requesting a special
697730
# interface with auto_bind disabled works
698731
if self._auto_bind or self._is_special_interface(interface):
699-
binding = ImplicitBinding(*self.create_binding(interface))
732+
binding = ImplicitBinding(**self.create_binding(interface).__dict__)
700733
self._bindings[interface] = binding
701734
return binding, self
702735

@@ -817,7 +850,7 @@ def __repr__(self) -> str:
817850
class NoScope(Scope):
818851
"""An unscoped provider."""
819852

820-
def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]:
853+
def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]:
821854
return provider
822855

823856

injector_test.py

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
"""Functional tests for the "Injector" dependency injection framework."""
1212

13-
from contextlib import contextmanager
14-
from dataclasses import dataclass
15-
from typing import Any, NewType, Optional, Union
1613
import abc
1714
import sys
1815
import threading
1916
import traceback
2017
import warnings
18+
from contextlib import contextmanager
19+
from dataclasses import dataclass
20+
from typing import Any, NewType, Optional, Union
2121

2222
if sys.version_info >= (3, 9):
2323
from typing import Annotated
@@ -29,32 +29,32 @@
2929
import pytest
3030

3131
from injector import (
32+
AssistedBuilder,
3233
Binder,
3334
CallError,
35+
CircularDependency,
36+
ClassAssistedBuilder,
37+
ClassProvider,
38+
Error,
3439
Inject,
3540
Injector,
41+
InstanceProvider,
42+
InvalidInterface,
43+
Module,
3644
NoInject,
45+
ProviderOf,
3746
Scope,
38-
InstanceProvider,
39-
ClassProvider,
47+
ScopeDecorator,
48+
SingletonScope,
49+
UnknownArgument,
50+
UnsatisfiedRequirement,
4051
get_bindings,
4152
inject,
4253
multiprovider,
4354
noninjectable,
55+
provider,
4456
singleton,
4557
threadlocal,
46-
UnsatisfiedRequirement,
47-
CircularDependency,
48-
Module,
49-
SingletonScope,
50-
ScopeDecorator,
51-
AssistedBuilder,
52-
provider,
53-
ProviderOf,
54-
ClassAssistedBuilder,
55-
Error,
56-
UnknownArgument,
57-
InvalidInterface,
5858
)
5959

6060

@@ -723,6 +723,65 @@ def configure_dict(binder: Binder):
723723
Injector([configure_dict])
724724

725725

726+
def test_multibind_types_respect_the_bound_type_scope() -> None:
727+
def configure(binder: Binder) -> None:
728+
binder.bind(PluginA, to=PluginA, scope=singleton)
729+
binder.multibind(List[Plugin], to=PluginA)
730+
731+
injector = Injector([configure])
732+
first_list = injector.get(List[Plugin])
733+
second_list = injector.get(List[Plugin])
734+
child_injector = injector.create_child_injector()
735+
third_list = child_injector.get(List[Plugin])
736+
737+
assert first_list[0] is second_list[0]
738+
assert third_list[0] is second_list[0]
739+
740+
741+
def test_multibind_list_scopes_applies_to_the_bound_items() -> None:
742+
def configure(binder: Binder) -> None:
743+
binder.multibind(List[Plugin], to=PluginA, scope=singleton)
744+
binder.multibind(List[Plugin], to=PluginB)
745+
binder.multibind(List[Plugin], to=[PluginC], scope=singleton)
746+
747+
injector = Injector([configure])
748+
first_list = injector.get(List[Plugin])
749+
second_list = injector.get(List[Plugin])
750+
751+
assert first_list is not second_list
752+
assert first_list[0] is second_list[0]
753+
assert first_list[1] is not second_list[1]
754+
assert first_list[2] is second_list[2]
755+
756+
757+
def test_multibind_dict_scopes_applies_to_the_bound_items() -> None:
758+
def configure(binder: Binder) -> None:
759+
binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton)
760+
binder.multibind(Dict[str, Plugin], to={'b': PluginB})
761+
binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton)
762+
763+
injector = Injector([configure])
764+
first_dict = injector.get(Dict[str, Plugin])
765+
second_dict = injector.get(Dict[str, Plugin])
766+
767+
assert first_dict is not second_dict
768+
assert first_dict['a'] is second_dict['a']
769+
assert first_dict['b'] is not second_dict['b']
770+
assert first_dict['c'] is second_dict['c']
771+
772+
773+
def test_multibind_scopes_does_not_apply_to_the_type_globally() -> None:
774+
def configure(binder: Binder) -> None:
775+
binder.multibind(List[Plugin], to=PluginA, scope=singleton)
776+
777+
injector = Injector([configure])
778+
plugins = injector.get(List[Plugin])
779+
780+
assert plugins[0] is not injector.get(PluginA)
781+
assert plugins[0] is not injector.get(Plugin)
782+
assert injector.get(PluginA) is not injector.get(PluginA)
783+
784+
726785
def test_regular_bind_and_provider_dont_work_with_multibind():
727786
# We only want multibind and multiprovider to work to avoid confusion
728787

0 commit comments

Comments
 (0)