Skip to content

Commit 44a6a68

Browse files
take common code from ContextLocalResource to Resource
1 parent 066d228 commit 44a6a68

File tree

2 files changed

+60
-117
lines changed

2 files changed

+60
-117
lines changed

src/dependency_injector/providers.pyx

Lines changed: 39 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,32 +3837,28 @@ cdef class Resource(Provider):
38373837

38383838
async def _handle_async_cm(self, obj) -> None:
38393839
try:
3840-
self._resource = resource = await obj.__aenter__()
3841-
self._shutdowner = obj.__aexit__
3840+
resource = await obj.__aenter__()
38423841
return resource
38433842
except:
38443843
self._initialized = False
38453844
raise
38463845

3847-
async def _provide_async(self, future) -> None:
3848-
try:
3849-
obj = await future
3850-
3851-
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
3852-
self._resource = await obj.__aenter__()
3853-
self._shutdowner = obj.__aexit__
3854-
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
3855-
self._resource = obj.__enter__()
3856-
self._shutdowner = obj.__exit__
3857-
else:
3858-
self._resource = obj
3859-
self._shutdowner = None
3846+
async def _provide_async(self, future):
3847+
obj = await future
38603848

3861-
return self._resource
3862-
except:
3863-
self._initialized = False
3864-
raise
3849+
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
3850+
resource = await obj.__aenter__()
3851+
shutdowner = obj.__aexit__
3852+
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
3853+
resource = obj.__enter__()
3854+
shutdowner = obj.__exit__
3855+
else:
3856+
resource = obj
3857+
shutdowner = None
38653858

3859+
return resource, shutdowner
3860+
3861+
38663862
cpdef object _provide(self, tuple args, dict kwargs):
38673863
if self._initialized:
38683864
return self._resource
@@ -3880,14 +3876,18 @@ cdef class Resource(Provider):
38803876

38813877
if __is_future_or_coroutine(obj):
38823878
self._initialized = True
3883-
self._resource = resource = ensure_future(self._provide_async(obj))
3884-
return resource
3879+
future_result = asyncio.Future()
3880+
future = ensure_future(self._provide_async(obj))
3881+
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
3882+
self._resource = future_result
3883+
return self._resource
38853884
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
38863885
self._resource = obj.__enter__()
38873886
self._shutdowner = obj.__exit__
38883887
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
38893888
self._initialized = True
38903889
self._resource = resource = ensure_future(self._handle_async_cm(obj))
3890+
self._shutdowner = obj.__aexit__
38913891
return resource
38923892
else:
38933893
self._resource = obj
@@ -3896,14 +3896,27 @@ cdef class Resource(Provider):
38963896
self._initialized = True
38973897
return self._resource
38983898

3899+
def _async_init_instance(self, future_result, result):
3900+
try:
3901+
resource, shutdowner = result.result()
3902+
except Exception as exception:
3903+
self._resource = None
3904+
self._shutdowner = None
3905+
self._initialized = False
3906+
future_result.set_exception(exception)
3907+
else:
3908+
self._resource = resource
3909+
self._shutdowner = shutdowner
3910+
future_result.set_result(resource)
3911+
38993912

39003913
cdef class ContextLocalResource(Resource):
39013914
_none = object()
39023915

39033916
def __init__(self, provides=None, *args, **kwargs):
39043917
self._initialized_context_var = ContextVar("_initialized_context_var", default=False)
3905-
self._resource_context_var = ContextVar("_resource_context_var", default=self._none)
3906-
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none)
3918+
self._resource_context_var = ContextVar("_resource_context_var", default=None)
3919+
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None)
39073920
super().__init__(provides, *args, **kwargs)
39083921

39093922
@property
@@ -3945,7 +3958,7 @@ cdef class ContextLocalResource(Resource):
39453958
return NULL_AWAITABLE
39463959
return
39473960

3948-
if self._shutdowner != self._none:
3961+
if self._shutdowner != None:
39493962
future = self._shutdowner(None, None, None)
39503963
if __is_future_or_coroutine(future):
39513964
self._reset_all_contex_vars()
@@ -3958,79 +3971,8 @@ cdef class ContextLocalResource(Resource):
39583971

39593972
def _reset_all_contex_vars(self):
39603973
self._initialized=False
3961-
self._resource = self._none
3962-
self._shutdowner = self._none
3963-
3964-
async def _handle_async_cm(self, obj) -> None:
3965-
resource = await obj.__aenter__()
3966-
return resource
3967-
3968-
async def _provide_async(self, future):
3969-
obj = await future
3970-
3971-
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
3972-
resource = await obj.__aenter__()
3973-
shutdowner = obj.__aexit__
3974-
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
3975-
resource = obj.__enter__()
3976-
shutdowner = obj.__exit__
3977-
else:
3978-
resource = obj
3979-
shutdowner = self._none
3980-
3981-
return resource, shutdowner
3982-
3983-
3984-
cpdef object _provide(self, tuple args, dict kwargs):
3985-
if self._initialized:
3986-
return self._resource
3987-
obj = __call(
3988-
self._provides,
3989-
args,
3990-
self._args,
3991-
self._args_len,
3992-
kwargs,
3993-
self._kwargs,
3994-
self._kwargs_len,
3995-
self._async_mode,
3996-
)
3997-
3998-
if __is_future_or_coroutine(obj):
3999-
future_result = asyncio.Future()
4000-
future = ensure_future(self._provide_async(obj))
4001-
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
4002-
return future_result
4003-
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
4004-
resource = obj.__enter__()
4005-
self._resource = resource
4006-
self._initialized = True
4007-
self._shutdowner = obj.__exit__
4008-
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
4009-
resource = ensure_future(self._handle_async_cm(obj))
4010-
self._resource = resource
4011-
self._initialized = True
4012-
self._shutdowner = obj.__aexit__
4013-
return resource
4014-
else:
4015-
self._resource = obj
4016-
self._initialized = True
4017-
self._shutdowner = self._none
4018-
4019-
return self._resource
4020-
4021-
def _async_init_instance(self, future_result, result):
4022-
try:
4023-
resource, shutdowner = result.result()
4024-
except Exception as exception:
4025-
self._resource = self._none
4026-
self._shutdowner = self._none
4027-
self._initialized = False
4028-
future_result.set_exception(exception)
4029-
else:
4030-
self._resource = resource
4031-
self._initialized = True
4032-
self._shutdowner = shutdowner
4033-
future_result.set_result(resource)
3974+
self._resource = None
3975+
self._shutdowner = None
40343976

40353977

40363978
cdef class Container(Provider):

tests/unit/providers/resource/test_context_local_resource_py38.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,27 @@ class Container(containers.DeclarativeContainer):
8888
context_local_resource = providers.ContextLocalResource(_init)
8989
async_context_local_resource = providers.ContextLocalResource(_async_init)
9090

91+
async def run_in_context():
92+
obj = await container.async_context_local_resource()
93+
return obj
94+
9195
container = Container()
92-
obj1 = await container.async_context_local_resource()
93-
obj2 = await container.async_context_local_resource()
94-
assert obj1 != obj2
9596

96-
obj3 = container.context_local_resource()
97-
obj4 = container.context_local_resource()
97+
obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context())
98+
assert obj1 != obj2
9899

100+
obj3 = await container.async_context_local_resource()
101+
obj4 = await container.async_context_local_resource()
99102
assert obj3 == obj4
100103

104+
obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context())
105+
assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized
106+
107+
obj7 = container.context_local_resource()
108+
obj8 = container.context_local_resource()
109+
110+
assert obj7 == obj8
111+
101112

102113
def test_init_function():
103114
def _init():
@@ -329,37 +340,27 @@ def test_call_with_context_args():
329340

330341

331342
def test_fluent_interface():
332-
provider = providers.ContextLocalResource(init_fn) \
333-
.add_args(1, 2) \
334-
.add_kwargs(a3=3, a4=4)
343+
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4)
335344
assert provider() == ((1, 2), {"a3": 3, "a4": 4})
336345

337346

338347
def test_set_args():
339-
provider = providers.ContextLocalResource(init_fn) \
340-
.add_args(1, 2) \
341-
.set_args(3, 4)
348+
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4)
342349
assert provider.args == (3, 4)
343350

344351

345352
def test_clear_args():
346-
provider = providers.ContextLocalResource(init_fn) \
347-
.add_args(1, 2) \
348-
.clear_args()
353+
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args()
349354
assert provider.args == tuple()
350355

351356

352357
def test_set_kwargs():
353-
provider = providers.ContextLocalResource(init_fn) \
354-
.add_kwargs(a1="i1", a2="i2") \
355-
.set_kwargs(a3="i3", a4="i4")
358+
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4")
356359
assert provider.kwargs == {"a3": "i3", "a4": "i4"}
357360

358361

359362
def test_clear_kwargs():
360-
provider = providers.ContextLocalResource(init_fn) \
361-
.add_kwargs(a1="i1", a2="i2") \
362-
.clear_kwargs()
363+
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs()
363364
assert provider.kwargs == {}
364365

365366

0 commit comments

Comments
 (0)