2222import threading
2323import types
2424from abc import ABCMeta , abstractmethod
25- from collections import namedtuple
25+ from dataclasses import dataclass
2626from typing import (
27+ TYPE_CHECKING ,
2728 Any ,
2829 Callable ,
29- cast ,
3030 Dict ,
31+ Generator ,
3132 Generic ,
32- get_args ,
3333 Iterable ,
3434 List ,
3535 Optional ,
36- overload ,
3736 Set ,
3837 Tuple ,
3938 Type ,
4039 TypeVar ,
41- TYPE_CHECKING ,
4240 Union ,
41+ cast ,
42+ get_args ,
43+ overload ,
4344)
4445
4546try :
5253# canonical. Since this typing_extensions import is only for mypy it'll work even without
5354# typing_extensions actually installed so all's good.
5455if TYPE_CHECKING :
55- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
56+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
5657else :
5758 # Ignoring errors here as typing_extensions stub doesn't know about those things yet
5859 try :
59- from typing import _AnnotatedAlias , Annotated , get_type_hints
60+ from typing import Annotated , _AnnotatedAlias , get_type_hints
6061 except ImportError :
61- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
62+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
6263
6364
6465__author__ = 'Alec Thomas <[email protected] >' @@ -340,39 +341,60 @@ def __repr__(self) -> str:
340341
341342
342343@private
343- class ListOfProviders (Provider , Generic [T ]):
344+ class MultiBinder (Provider , Generic [T ]):
344345 """Provide a list of instances via other Providers."""
345346
346- _providers : List [Provider [ T ] ]
347+ _multi_bindings : List ['Binding' ]
347348
348- def __init__ (self ) -> None :
349- self ._providers = []
349+ def __init__ (self , parent : 'Binder' ) -> None :
350+ self ._multi_bindings = []
351+ self ._binder = Binder (parent .injector , auto_bind = False , parent = parent )
350352
351- def append (self , provider : Provider [T ]) -> None :
352- self ._providers .append (provider )
353+ def append (self , provider : Provider [T ], scope : Type ['Scope' ]) -> None :
354+ # HACK: generate a pseudo-type for this element in the list.
355+ # This is needed for scopes to work properly. Some, like the Singleton scope,
356+ # key instances by type, so we need one that is unique to this binding.
357+ pseudo_type = type (f"multibind-type-{ id (provider )} " , (provider .__class__ ,), {})
358+ self ._multi_bindings .append (Binding (pseudo_type , provider , scope ))
359+
360+ def get_scoped_providers (self , injector : 'Injector' ) -> Generator [Provider [T ], None , None ]:
361+ for binding in self ._multi_bindings :
362+ if (
363+ isinstance (binding .provider , ClassProvider )
364+ and binding .scope is NoScope
365+ and self ._binder .parent
366+ and self ._binder .parent .has_explicit_binding_for (binding .provider ._cls )
367+ ):
368+ parent_binding , _ = self ._binder .parent .get_binding (binding .provider ._cls )
369+ scope_binding , _ = self ._binder .parent .get_binding (parent_binding .scope )
370+ else :
371+ scope_binding , _ = self ._binder .get_binding (binding .scope )
372+ scope_instance : Scope = scope_binding .provider .get (injector )
373+ provider_instance = scope_instance .get (binding .interface , binding .provider )
374+ yield provider_instance
353375
354376 def __repr__ (self ) -> str :
355- return '%s(%r)' % (type (self ).__name__ , self ._providers )
377+ return '%s(%r)' % (type (self ).__name__ , self ._multi_bindings )
356378
357379
358- class MultiBindProvider (ListOfProviders [List [T ]]):
380+ class MultiBindProvider (MultiBinder [List [T ]]):
359381 """Used by :meth:`Binder.multibind` to flatten results of providers that
360382 return sequences."""
361383
362384 def get (self , injector : 'Injector' ) -> List [T ]:
363385 result : List [T ] = []
364- for provider in self ._providers :
386+ for provider in self .get_scoped_providers ( injector ) :
365387 instances : List [T ] = _ensure_iterable (provider .get (injector ))
366388 result .extend (instances )
367389 return result
368390
369391
370- class MapBindProvider (ListOfProviders [Dict [str , T ]]):
392+ class MapBindProvider (MultiBinder [Dict [str , T ]]):
371393 """A provider for map bindings."""
372394
373395 def get (self , injector : 'Injector' ) -> Dict [str , T ]:
374396 map : Dict [str , T ] = {}
375- for provider in self ._providers :
397+ for provider in self .get_scoped_providers ( injector ) :
376398 map .update (provider .get (injector ))
377399 return map
378400
@@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
387409 return {self ._key : self ._provider .get (injector )}
388410
389411
390- _BindingBase = namedtuple ('_BindingBase' , 'interface provider scope' )
412+ @dataclass
413+ class _BindingBase :
414+ interface : type
415+ provider : Provider
416+ scope : Type ['Scope' ]
391417
392418
393419@private
@@ -531,44 +557,51 @@ def multibind(
531557
532558 :param scope: Optional Scope in which to bind.
533559 """
534- if interface not in self ._bindings :
535- provider : ListOfProviders
536- if (
537- isinstance (interface , dict )
538- or isinstance (interface , type )
539- and issubclass (interface , dict )
540- or _get_origin (_punch_through_alias (interface )) is dict
541- ):
542- provider = MapBindProvider ()
543- else :
544- provider = MultiBindProvider ()
545- binding = self .create_binding (interface , provider , scope )
546- self ._bindings [interface ] = binding
547- else :
548- binding = self ._bindings [interface ]
549- provider = binding .provider
550- assert isinstance (provider , ListOfProviders )
551-
552- if isinstance (provider , MultiBindProvider ) and isinstance (to , list ):
560+ multi_binder = self ._get_multi_binder (interface )
561+ if isinstance (multi_binder , MultiBindProvider ) and isinstance (to , list ):
553562 try :
554563 element_type = get_args (_punch_through_alias (interface ))[0 ]
555564 except IndexError :
556565 raise InvalidInterface (
557566 f"Use typing.List[T] or list[T] to specify the element type of the list"
558567 )
559568 for element in to :
560- provider .append (self .provider_for (element_type , element ))
561- elif isinstance (provider , MapBindProvider ) and isinstance (to , dict ):
569+ element_binding = self .create_binding (element_type , element , scope )
570+ multi_binder .append (element_binding .provider , element_binding .scope )
571+ elif isinstance (multi_binder , MapBindProvider ) and isinstance (to , dict ):
562572 try :
563573 value_type = get_args (_punch_through_alias (interface ))[1 ]
564574 except IndexError :
565575 raise InvalidInterface (
566576 f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
567577 )
568578 for key , value in to .items ():
569- provider .append (KeyValueProvider (key , self .provider_for (value_type , value )))
579+ element_binding = self .create_binding (value_type , value , scope )
580+ multi_binder .append (KeyValueProvider (key , element_binding .provider ), element_binding .scope )
570581 else :
571- provider .append (self .provider_for (interface , to ))
582+ element_binding = self .create_binding (interface , to , scope )
583+ multi_binder .append (element_binding .provider , element_binding .scope )
584+
585+ def _get_multi_binder (self , interface : type ) -> MultiBinder :
586+ multi_binder : MultiBinder
587+ if interface not in self ._bindings :
588+ if (
589+ isinstance (interface , dict )
590+ or isinstance (interface , type )
591+ and issubclass (interface , dict )
592+ or _get_origin (_punch_through_alias (interface )) is dict
593+ ):
594+ multi_binder = MapBindProvider (self )
595+ else :
596+ multi_binder = MultiBindProvider (self )
597+ binding = self .create_binding (interface , multi_binder )
598+ self ._bindings [interface ] = binding
599+ else :
600+ binding = self ._bindings [interface ]
601+ assert isinstance (binding .provider , MultiBinder )
602+ multi_binder = binding .provider
603+
604+ return multi_binder
572605
573606 def install (self , module : _InstallableModuleType ) -> None :
574607 """Install a module into this binder.
@@ -611,10 +644,10 @@ def create_binding(
611644 self , interface : type , to : Any = None , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ] = None
612645 ) -> Binding :
613646 provider = self .provider_for (interface , to )
614- scope = scope or getattr (to or interface , '__scope__' , NoScope )
647+ scope = scope or getattr (to or interface , '__scope__' , None )
615648 if isinstance (scope , ScopeDecorator ):
616649 scope = scope .scope
617- return Binding (interface , provider , scope )
650+ return Binding (interface , provider , scope or NoScope )
618651
619652 def provider_for (self , interface : Any , to : Any = None ) -> Provider :
620653 base_type = _punch_through_alias (interface )
@@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
696729 # The special interface is added here so that requesting a special
697730 # interface with auto_bind disabled works
698731 if self ._auto_bind or self ._is_special_interface (interface ):
699- binding = ImplicitBinding (* self .create_binding (interface ))
732+ binding = ImplicitBinding (** self .create_binding (interface ). __dict__ )
700733 self ._bindings [interface ] = binding
701734 return binding , self
702735
@@ -817,7 +850,7 @@ def __repr__(self) -> str:
817850class NoScope (Scope ):
818851 """An unscoped provider."""
819852
820- def get (self , unused_key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
853+ def get (self , key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
821854 return provider
822855
823856
0 commit comments