Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/wiring/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from typing import Annotated


class Service:
Expand All @@ -12,12 +13,18 @@ class Container(containers.DeclarativeContainer):

service = providers.Factory(Service)


# You can place marker on parameter default value
@inject
def main(service: Service = Provide[Container.service]) -> None:
...


# Also, you can place marker with typing.Annotated
@inject
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None:
...


if __name__ == "__main__":
container = Container()
container.wire(modules=[__name__])
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ numpy
scipy
boto3
mypy_boto3_s3
typing_extensions

-r requirements-ext.txt
45 changes: 35 additions & 10 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ class GenericMeta(type):
else:
GenericAlias = None

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
try:
from typing_extensions import Annotated, get_args, get_origin
except ImportError:
Annotated = object()

# For preventing NameError. Never executes
def get_args(hint):
return ()

def get_origin(tp):
return None

try:
import fastapi.params
Expand Down Expand Up @@ -548,6 +562,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
setattr(patched.member, patched.name, patched.marker)


def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
if get_origin(parameter.annotation) is Annotated:
marker = get_args(parameter.annotation)[1]
else:
marker = parameter.default

if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
return None

if _is_fastapi_depends(marker):
marker = marker.dependency

if not isinstance(marker, _Marker):
return None

return marker


def _fetch_reference_injections( # noqa: C901
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand All @@ -573,17 +605,10 @@ def _fetch_reference_injections( # noqa: C901
injections = {}
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker) \
and not _is_fastapi_depends(parameter.default):
continue
marker = _extract_marker(parameter)

marker = parameter.default

if _is_fastapi_depends(marker):
marker = marker.dependency

if not isinstance(marker, _Marker):
continue
if marker is None:
continue

if isinstance(marker, Closing):
marker = marker.provider
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/samples/wiringfastapi/web.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys

from typing_extensions import Annotated

from fastapi import FastAPI, Depends
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi.security import HTTPBasic, HTTPBasicCredentials
Expand Down Expand Up @@ -27,6 +29,11 @@ async def index(service: Service = Depends(Provide[Container.service])):
result = await service.process()
return {"result": result}

@app.api_route('/annotated')
@inject
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
result = await service.process()
return {'result': result}

@app.get("/auth")
@inject
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/samples/wiringflask/web.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing_extensions import Annotated

from flask import Flask, jsonify, request, current_app, session, g
from flask import _request_ctx_stack, _app_ctx_stack
from dependency_injector import containers, providers
Expand Down Expand Up @@ -28,5 +30,12 @@ def index(service: Service = Provide[Container.service]):
return jsonify({"result": result})


@app.route("/annotated")
@inject
def annotated(service: Annotated[Service, Provide[Container.service]]):
result = service.process()
return jsonify({'result': result})


container = Container()
container.wire(modules=[__name__])
14 changes: 14 additions & 0 deletions tests/unit/wiring/test_fastapi_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ async def process(self):
assert response.json() == {"result": "Foo"}


@mark.asyncio
async def test_depends_with_annotated(async_client: AsyncClient):
class ServiceMock:
async def process(self):
return "Foo"

with web.container.service.override(ServiceMock()):
response = await async_client.get("/")

assert response.status_code == 200
assert response.json() == {"result": "Foo"}



@mark.asyncio
async def test_depends_injection(async_client: AsyncClient):
response = await async_client.get("/auth", auth=("john_smith", "secret"))
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/wiring/test_flask_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ def test_wiring_with_flask():

assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}


def test_wiring_with_annotated():
client = web.app.test_client()

with web.app.app_context():
response = client.get("/annotated")

assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}