Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkgs/standards/autoapi/autoapi/v2/cfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"__autoapi_allow_anon__",
"__autoapi_verb_aliases__",
"__autoapi_verb_alias_policy__",
"__autoapi_security_deps__",
}

# Routing configuration attributes
Expand Down
9 changes: 8 additions & 1 deletion pkgs/standards/autoapi/autoapi/v2/impl/routes_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ def _register_routes_and_rpcs( # noqa: N802 – bound as method
else (flat_router, APIRouter(prefix=nested_pref, tags=[f"nested-{resource}"]))
)

sec_dep_cb = getattr(model, "__autoapi_security_deps__", None)
if callable(sec_dep_cb):
_security_deps = list(sec_dep_cb())
else:
_security_deps = list(sec_dep_cb or [])

# ---------- RBAC guard -------------------------------------------
def _guard(scope: str):
async def inner(request: Request):
Expand Down Expand Up @@ -507,7 +513,8 @@ def _direct_call(_m, _p, _db=db):

# mount on routers
for rtr in routers:
deps = [_guard(m_id_canon)]
deps = list(_security_deps)
deps.append(_guard(m_id_canon))
if m_id_canon not in self._allow_anon:
deps.insert(0, self._authn_dep)
print(f"Mounting route {path} for verb {verb} on router {rtr}")
Expand Down
6 changes: 6 additions & 0 deletions pkgs/standards/autoapi/autoapi/v2/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@
ResponseExtrasProvider,
list_response_extras_providers,
)
from .security_deps_provider import (
SecurityDepsProvider,
list_security_deps_providers,
)

from .op_verb_alias_provider import OpVerbAliasProvider, list_verb_alias_providers

Expand All @@ -91,6 +95,7 @@ def hex(self):
"NestedPathProvider",
"AllowAnonProvider",
"ResponseExtrasProvider",
"SecurityDepsProvider",
# builtin types
"MethodType",
"SimpleNamespace",
Expand Down Expand Up @@ -152,4 +157,5 @@ def hex(self):
"OpVerbAliasProvider",
"list_verb_alias_providers",
"list_response_extras_providers",
"list_security_deps_providers",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Callable, ClassVar, Iterable

from .table_config_provider import TableConfigProvider

_SECURITY_DEPS_PROVIDERS: set[type] = set()


class SecurityDepsProvider(TableConfigProvider):
"""Models that define extra security dependencies for routes."""

__autoapi_security_deps__: ClassVar[Iterable | Callable[[], Iterable]] = ()

def __init_subclass__(cls, **kw):
super().__init_subclass__(**kw)
_SECURITY_DEPS_PROVIDERS.add(cls)


def list_security_deps_providers():
return sorted(_SECURITY_DEPS_PROVIDERS, key=lambda c: c.__name__)
52 changes: 52 additions & 0 deletions pkgs/standards/autoapi/tests/i9n/test_security_deps_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.testclient import TestClient
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from autoapi.v2 import AutoAPI, Base
from autoapi.v2.mixins import GUIDPk


def _build_client():
Base.metadata.clear()

class Item(Base, GUIDPk):
__tablename__ = "items"
name = Column(String, nullable=False)

@classmethod
def __autoapi_security_deps__(cls):
def verify(x_token: str = Header(None)):
if x_token != "secret":
raise HTTPException(status_code=401)

return [Depends(verify)]

engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

def get_db():
with SessionLocal() as session:
yield session

api = AutoAPI(base=Base, include={Item}, get_db=get_db)
app = FastAPI()
app.include_router(api.router)
api.initialize_sync()
return TestClient(app)


def test_security_deps_enforced():
client = _build_client()
payload = {"name": "thing"}
assert client.post("/item", json=payload).status_code == 401
res = client.post("/item", json=payload, headers={"x-token": "secret"})
assert res.status_code == 201
iid = res.json()["id"]
assert client.get(f"/item/{iid}").status_code == 401
assert client.get(f"/item/{iid}", headers={"x-token": "secret"}).status_code == 200