From 97801eb1c5495e9fe7d5286e0d9743a137c64448 Mon Sep 17 00:00:00 2001 From: cobycloud <25079070+cobycloud@users.noreply.github.com> Date: Mon, 25 Aug 2025 22:53:31 -0500 Subject: [PATCH 1/2] feat: add hook contexts for authn models --- docs/docs/concept/auto_authn_v2_overview.md | 76 +++++++++++++++++++ .../auto_authn/auto_authn/v2/orm/api_key.py | 46 +++++++++-- .../auto_authn/auto_authn/v2/orm/auth_code.py | 54 ++++++------- .../auto_authn/v2/orm/auth_session.py | 64 +++++++++++----- .../auto_authn/auto_authn/v2/orm/client.py | 10 ++- .../auto_authn/v2/orm/device_code.py | 68 +++++++++-------- .../v2/orm/pushed_authorization_request.py | 20 +++-- .../auto_authn/v2/orm/revoked_token.py | 14 +++- .../auto_authn/auto_authn/v2/orm/service.py | 21 +++-- .../auto_authn/v2/orm/service_key.py | 57 +++++++++++--- .../auto_authn/auto_authn/v2/orm/tables.py | 2 +- .../auto_authn/auto_authn/v2/orm/tenant.py | 11 +-- .../auto_authn/auto_authn/v2/orm/user.py | 35 +++++++-- 13 files changed, 358 insertions(+), 120 deletions(-) create mode 100644 docs/docs/concept/auto_authn_v2_overview.md diff --git a/docs/docs/concept/auto_authn_v2_overview.md b/docs/docs/concept/auto_authn_v2_overview.md new file mode 100644 index 0000000000..c2861ee06a --- /dev/null +++ b/docs/docs/concept/auto_authn_v2_overview.md @@ -0,0 +1,76 @@ +# auto_authn.v2 Overview + +## Router Endpoint Modules +- `rfc7591.py` – client registration endpoint +- `rfc9126.py` – pushed authorization request endpoint +- `rfc8628.py` – device authorization endpoint +- `rfc8932.py` – enhanced authorization server metadata endpoint +- `rfc8414.py` – OAuth 2.0 authorization server metadata endpoint +- `rfc7009.py` – token revocation endpoint +- `rfc8693.py` – token exchange endpoint +- `oidc_userinfo.py` – OpenID Connect `/userinfo` endpoint +- `oidc_discovery.py` – OpenID discovery endpoints +- `routers/auth_flows.py` – aggregates authentication and authorization routers +- `routers/authn/__init__.py` – authentication routes (register/login/logout) +- `routers/authz/__init__.py` – authorization routes (token, introspection, etc.) +- `routers/crud.py` – AutoAPI-generated CRUD router for ORM models + +## Schema Modules +- `routers/schemas.py` – request/response models for auth flows +- `rfc7591.py` – `ClientMetadata` for dynamic client registration +- `rfc8628.py` – `DeviceAuthIn`, `DeviceAuthOut`, and `DeviceGrantForm` +- `rfc9396.py` – `AuthorizationDetail` for rich authorization requests +- `rfc8693.py` – `TokenType`, `TokenExchangeRequest`, and `TokenExchangeResponse` +- `orm/` modules – SQLAlchemy models (`Tenant`, `User`, `Client`, `Service`, `ApiKey`, `ServiceKey`, `AuthSession`, `AuthCode`, `DeviceCode`, `RevokedToken`, `PushedAuthorizationRequest`) + +### Persistence vs. Virtual Schemas +- **Persistent**: ORM models from `orm/` (listed above) +- **Virtual**: Pydantic/in-memory classes such as `RegisterIn`, `TokenPair`, `ClientMetadata`, `DeviceAuthIn`, `DeviceAuthOut`, `DeviceGrantForm`, `AuthorizationDetail`, `TokenType`, `TokenExchangeRequest`, and `TokenExchangeResponse` + +## Crypto Modules +- `crypto.py` – bcrypt password hashing and Ed25519 key management +- `rfc7515.py` – JSON Web Signature helpers +- `rfc7516.py` – JSON Web Encryption helpers +- `rfc7517.py` – loading signing and public JWKs +- `rfc7518.py` – supported JOSE algorithms list +- `rfc7519.py` – JWT encode/decode wrappers +- `rfc7638.py` – JWK thumbprint generation and verification +- `rfc7800.py` – confirmation claim and proof-of-possession utilities +- `rfc8291.py` – AES-128-GCM encryption/decryption for push messages +- `rfc8037.py` – EdDSA signing and verification helpers +- `rfc8705.py` – certificate thumbprint and binding validation +- `rfc9449_dpop.py` – DPoP proof creation and verification + +## ORM Tables, Columns, and Operations +| Table | Acols (stored columns) | Vcols (virtual/relationships) | Default Ops | Additional Ops | Hook Context | +|-------|------------------------|-------------------------------|-------------|----------------|--------------| +| `Tenant` | `id`, `slug`, `created_at`, `updated_at`, `name`, `email` | — | create, read, update, delete, list | — | — | +| `User` | `id`, `created_at`, `updated_at`, `tenant_id`, `username`, `email`, `password_hash`, `is_active` | `api_keys` relationship | create, read, update, delete, list | register, login | `hash_pw` pre-create/pre-update for password hashing | +| `Client` | `id`, `created_at`, `updated_at`, `tenant_id`, `client_secret_hash`, `redirect_uris`, `is_active` | — | create, read, update, delete, list | dynamic client registration (`rfc7591`) | optional `hash_client_secret` hook | +| `Service` | `id`, `created_at`, `updated_at`, `tenant_id`, `is_active`, `name` | `service_keys` relationship | create, read, update, delete, list | — | `encrypt_service_key` if needed | +| `ApiKey` | `id`, `created_at`, `last_used_at`, `valid_from`, `valid_to`, `label`, `digest`, `user_id` | `user` relationship | create, read, update, delete, list | generate/return raw key | pre-create `generate_api_key`, post-create `return_raw_key` | +| `ServiceKey` | same as `ApiKey` plus `service_id` | `service` relationship | create, read, update, delete, list | — | similar hooks as `ApiKey` | +| `AuthSession` | `id`, `user_id`, `tenant_id`, `username`, `auth_time`, `created_at`, `updated_at` | — | create, read, update, delete, list | — | — | +| `AuthCode` | `code`, `user_id`, `tenant_id`, `client_id`, `redirect_uri`, `code_challenge`, `nonce`, `scope`, `expires_at`, `claims`, `created_at`, `updated_at` | — | create, read, update, delete, list | — | — | +| `DeviceCode` | `device_code`, `user_code`, `client_id`, `expires_at`, `interval`, `authorized`, `user_id`, `tenant_id`, `created_at`, `updated_at` | — | create, read, update, delete, list | device authorization | `issue_device_code`, `notify_user_agent` hooks when persisted | +| `RevokedToken` | `token`, `created_at`, `updated_at` | — | create, read, update, delete, list | revoke | `store_revoked_token` pre-create | +| `PushedAuthorizationRequest` | `request_uri`, `params`, `expires_at`, `created_at`, `updated_at` | — | create, read, update, delete, list | pushed authorization request | `persist_par_request` pre-create | + +### Notes on Operations +- **op_alias**: no explicit overrides; CRUD uses default verbs. +- **schema_ctx**: use when virtual fields (e.g., `password`) cannot map directly to persistent columns. +- **Lifecycle hooks**: attach callables like `_pwd_backend.verify`, `_jwt.encode`, or custom crypto/providers to appropriate `pre_*` or `post_*` phases. + +## Hook Context Examples +| Endpoint | Lifecycle Hook | Purpose | +|----------|----------------|---------| +| `POST /register` | `pre_create` → `hash_pw` | Hash incoming password before persisting `User` record | +| `POST /login` | `pre_read` → `_pwd_backend.verify` | Verify password before issuing tokens | +| CRUD `ApiKey` | `pre_create` → `generate_api_key`, `post_create` → `return_raw_key` | Create digest and return plaintext key | +| `POST /token` | `pre_read`/`post_create` → `_pwd_backend.verify`, `_jwt.encode` | Validate client secrets and sign issued tokens | +| `POST /revoke` | `pre_create` → `store_revoked_token` | Persist revoked tokens to `RevokedToken` table | +| `POST /device_authorization` | `pre_create`/`post_create` → `issue_device_code`, `notify_user_agent` | Generate and optionally persist device/user codes | +| `POST /par` | `pre_create` → `persist_par_request` | Store pushed authorization request | +| `POST /token/exchange` | `post_create` → `_jwt.encode` | Sign exchanged tokens | +| `GET /userinfo` | `post_read` → `_jwt.encode` (optional) | Optionally sign the userinfo response | + diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/api_key.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/api_key.py index be7727c1b6..d55ad023fc 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/api_key.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/api_key.py @@ -2,9 +2,18 @@ from __future__ import annotations -from autoapi.v2.tables import ApiKey as ApiKeyBase -from autoapi.v2.types import UniqueConstraint, relationship -from autoapi.v2.mixins import UserMixin +import hashlib +import secrets + +from autoapi.v3.tables import ApiKey as ApiKeyBase +from autoapi.v3.types import UniqueConstraint, relationship +from autoapi.v3.mixins import UserMixin +from autoapi.v3.specs import IO, vcol +from autoapi.v3 import hook_ctx +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from .user import User class ApiKey(ApiKeyBase, UserMixin): @@ -13,11 +22,38 @@ class ApiKey(ApiKeyBase, UserMixin): {"extend_existing": True, "schema": "authn"}, ) - user = relationship( + _user = relationship( "auto_authn.v2.orm.tables.User", - back_populates="api_keys", + back_populates="_api_keys", lazy="joined", # optional: eager load to avoid N+1 ) + user: "User" = vcol( + read_producer=lambda obj, _ctx: obj._user, + io=IO(out_verbs=("read", "list")), + ) + + @hook_ctx(ops="create", phase="PRE_HANDLER") + async def _generate_digest(cls, ctx): + payload = ctx.get("payload") or {} + token = secrets.token_urlsafe(32) + payload["digest"] = hashlib.sha256(token.encode()).hexdigest() + ctx["raw_key"] = token + + @hook_ctx(ops="create", phase="POST_RESPONSE") + async def _return_raw_key(cls, ctx): + raw = ctx.get("raw_key") + result = ctx.get("result") + if not raw or result is None: + return + if hasattr(result, "model_dump"): + data = result.model_dump() + elif hasattr(result, "dict") and callable(result.dict): + data = result.dict() # type: ignore[call-arg] + else: + data = dict(result) + data["raw_key"] = raw + ctx["result"] = data + __all__ = ["ApiKey"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_code.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_code.py index 0e9d3833c8..29732ee2cf 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_code.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_code.py @@ -2,39 +2,39 @@ from __future__ import annotations -from autoapi.v2 import Base -from autoapi.v2.mixins import Timestamped -from autoapi.v2.types import Column, ForeignKey, JSON, PgUUID, String, TZDateTime +import datetime as dt +import uuid +from autoapi.v3.tables import Base +from autoapi.v3.mixins import TenantMixin, Timestamped, UserMixin +from autoapi.v3.specs import S, acol +from autoapi.v3.specs.storage_spec import ForeignKeySpec +from autoapi.v3.types import JSON, PgUUID, String, TZDateTime +from autoapi.v3 import op_ctx -class AuthCode(Base, Timestamped): + +class AuthCode(Base, Timestamped, UserMixin, TenantMixin): __tablename__ = "auth_codes" __table_args__ = ({"schema": "authn"},) - code = Column(String(128), primary_key=True) - user_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.users.id"), - nullable=False, - index=True, - ) - tenant_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.tenants.id"), - nullable=False, - index=True, + code: str = acol(storage=S(String(128), primary_key=True)) + client_id: uuid.UUID = acol( + storage=S( + PgUUID(as_uuid=True), + fk=ForeignKeySpec(target="authn.clients.id"), + nullable=False, + ) ) - client_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.clients.id"), - nullable=False, - ) - redirect_uri = Column(String(1000), nullable=False) - code_challenge = Column(String, nullable=True) - nonce = Column(String, nullable=True) - scope = Column(String, nullable=True) - expires_at = Column(TZDateTime, nullable=False) - claims = Column(JSON, nullable=True) + redirect_uri: str = acol(storage=S(String(1000), nullable=False)) + code_challenge: str | None = acol(storage=S(String, nullable=True)) + nonce: str | None = acol(storage=S(String, nullable=True)) + scope: str | None = acol(storage=S(String, nullable=True)) + expires_at: dt.datetime = acol(storage=S(TZDateTime, nullable=False)) + claims: dict | None = acol(storage=S(JSON, nullable=True)) + + @op_ctx(alias="exchange", target="delete", arity="member") + def exchange(cls, ctx, obj): + return obj __all__ = ["AuthCode"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_session.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_session.py index ae1dfcfafa..1d0ef1eedf 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_session.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/auth_session.py @@ -4,32 +4,56 @@ import datetime as dt -from autoapi.v2 import Base -from autoapi.v2.mixins import Timestamped -from autoapi.v2.types import Column, ForeignKey, PgUUID, String, TZDateTime +from autoapi.v3.tables import Base +from autoapi.v3.mixins import TenantMixin, Timestamped, UserMixin +from autoapi.v3.specs import S, acol +from autoapi.v3.types import String, TZDateTime +from autoapi.v3 import op_ctx, hook_ctx +from fastapi import HTTPException -class AuthSession(Base, Timestamped): +class AuthSession(Base, Timestamped, UserMixin, TenantMixin): __tablename__ = "sessions" __table_args__ = ({"schema": "authn"},) - id = Column(String(64), primary_key=True) - user_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.users.id"), - nullable=False, - index=True, - ) - tenant_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.tenants.id"), - nullable=False, - index=True, - ) - username = Column(String(120), nullable=False) - auth_time = Column( - TZDateTime, default=lambda: dt.datetime.now(dt.timezone.utc), nullable=False + id: str = acol(storage=S(String(64), primary_key=True)) + username: str = acol(storage=S(String(120), nullable=False)) + auth_time: dt.datetime = acol( + storage=S( + TZDateTime, nullable=False, default=lambda: dt.datetime.now(dt.timezone.utc) + ) ) + @hook_ctx(ops="login", phase="PRE_HANDLER") + async def _verify_credentials(cls, ctx): + from .user import User + + payload = ctx.get("payload") or {} + db = ctx.get("db") + username = payload.get("username") + password = payload.get("password") + if db is None or not username or not password: + raise HTTPException(status_code=400, detail="missing credentials") + + users = await User.handlers.list.core( + {"db": db, "payload": {"filters": {"username": username}}} + ) + user = users.items[0] if getattr(users, "items", None) else None + if user is None or not user.verify_password(password): + raise HTTPException(status_code=400, detail="invalid credentials") + + payload.pop("password", None) + payload["user_id"] = user.id + payload["tenant_id"] = user.tenant_id + payload["username"] = user.username + + @op_ctx(alias="login", target="create", arity="collection") + def login(cls, ctx): + pass + + @op_ctx(alias="logout", target="delete", arity="member") + def logout(cls, ctx, obj): + return obj + __all__ = ["AuthSession"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/client.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/client.py index d41741ea0c..df6d9eba67 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/client.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/client.py @@ -6,7 +6,8 @@ import uuid from typing import Final -from autoapi.v2.tables import Client as ClientBase +from autoapi.v3.tables import Client as ClientBase +from autoapi.v3 import hook_ctx from ..crypto import hash_pw from ..rfc8252 import validate_native_redirect_uri @@ -18,6 +19,13 @@ class Client(ClientBase): __table_args__ = ({"schema": "authn"},) + @hook_ctx(ops=("create", "update"), phase="PRE_HANDLER") + async def _hash_secret(cls, ctx): + payload = ctx.get("payload") or {} + secret = payload.pop("client_secret", None) + if secret: + payload["client_secret_hash"] = hash_pw(secret) + @classmethod def new( cls, diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/device_code.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/device_code.py index f1c7a8f670..60214ce974 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/device_code.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/device_code.py @@ -2,45 +2,53 @@ from __future__ import annotations -from autoapi.v2 import Base -from autoapi.v2.mixins import Timestamped -from autoapi.v2.types import ( - Boolean, - Column, - ForeignKey, - Integer, - PgUUID, - String, - TZDateTime, -) +import datetime as dt +import uuid + +from autoapi.v3.tables import Base +from autoapi.v3.mixins import Timestamped +from autoapi.v3.specs import S, acol +from autoapi.v3.specs.storage_spec import ForeignKeySpec +from autoapi.v3.types import Boolean, Integer, PgUUID, String, TZDateTime +from autoapi.v3 import op_ctx class DeviceCode(Base, Timestamped): __tablename__ = "device_codes" __table_args__ = ({"schema": "authn"},) - device_code = Column(String(128), primary_key=True) - user_code = Column(String(32), nullable=False, index=True) - client_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.clients.id"), - nullable=False, + device_code: str = acol(storage=S(String(128), primary_key=True)) + user_code: str = acol(storage=S(String(32), nullable=False, index=True)) + client_id: uuid.UUID = acol( + storage=S( + PgUUID(as_uuid=True), + fk=ForeignKeySpec(target="authn.clients.id"), + nullable=False, + ) ) - expires_at = Column(TZDateTime, nullable=False) - interval = Column(Integer, nullable=False) - authorized = Column(Boolean, default=False, nullable=False) - user_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.users.id"), - nullable=True, - index=True, + expires_at: dt.datetime = acol(storage=S(TZDateTime, nullable=False)) + interval: int = acol(storage=S(Integer, nullable=False)) + authorized: bool = acol(storage=S(Boolean, nullable=False, default=False)) + user_id: uuid.UUID | None = acol( + storage=S( + PgUUID(as_uuid=True), + fk=ForeignKeySpec(target="authn.users.id"), + nullable=True, + index=True, + ) ) - tenant_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.tenants.id"), - nullable=True, - index=True, + tenant_id: uuid.UUID | None = acol( + storage=S( + PgUUID(as_uuid=True), + fk=ForeignKeySpec(target="authn.tenants.id"), + nullable=True, + index=True, + ) ) + @op_ctx(alias="device_authorization", target="create", arity="collection") + def device_authorization(cls, ctx): + pass + __all__ = ["DeviceCode"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/pushed_authorization_request.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/pushed_authorization_request.py index af047b99d8..a2d1c1f901 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/pushed_authorization_request.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/pushed_authorization_request.py @@ -2,18 +2,26 @@ from __future__ import annotations -from autoapi.v2 import Base -from autoapi.v2.mixins import Timestamped -from autoapi.v2.types import Column, JSON, String, TZDateTime +import datetime as dt + +from autoapi.v3.tables import Base +from autoapi.v3.mixins import Timestamped +from autoapi.v3.specs import S, acol +from autoapi.v3.types import JSON, String, TZDateTime +from autoapi.v3 import op_ctx class PushedAuthorizationRequest(Base, Timestamped): __tablename__ = "par_requests" __table_args__ = ({"schema": "authn"},) - request_uri = Column(String(255), primary_key=True) - params = Column(JSON, nullable=False) - expires_at = Column(TZDateTime, nullable=False) + request_uri: str = acol(storage=S(String(255), primary_key=True)) + params: dict = acol(storage=S(JSON, nullable=False)) + expires_at: dt.datetime = acol(storage=S(TZDateTime, nullable=False)) + + @op_ctx(alias="par", target="create", arity="collection") + def par(cls, ctx): + pass __all__ = ["PushedAuthorizationRequest"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/revoked_token.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/revoked_token.py index cec699a0b1..7b19e81b37 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/revoked_token.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/revoked_token.py @@ -2,16 +2,22 @@ from __future__ import annotations -from autoapi.v2 import Base -from autoapi.v2.mixins import Timestamped -from autoapi.v2.types import Column, String +from autoapi.v3.tables import Base +from autoapi.v3.mixins import Timestamped +from autoapi.v3.specs import S, acol +from autoapi.v3.types import String +from autoapi.v3 import op_ctx class RevokedToken(Base, Timestamped): __tablename__ = "revoked_tokens" __table_args__ = ({"schema": "authn"},) - token = Column(String(512), primary_key=True) + token: str = acol(storage=S(String(512), primary_key=True)) + + @op_ctx(alias="revoke", target="create", arity="collection") + def revoke(cls, ctx): + pass __all__ = ["RevokedToken"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/service.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/service.py index b34a2f7dc8..13f4c26999 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/service.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/service.py @@ -2,9 +2,14 @@ from __future__ import annotations -from autoapi.v2 import Base -from autoapi.v2.mixins import GUIDPk, Timestamped, TenantBound, Principal, ActiveToggle -from autoapi.v2.types import Column, String, relationship +from autoapi.v3.tables import Base +from autoapi.v3.mixins import GUIDPk, Timestamped, TenantBound, Principal, ActiveToggle +from autoapi.v3.types import String, relationship +from autoapi.v3.specs import IO, S, acol, vcol +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from .service_key import ServiceKey class Service(Base, GUIDPk, Timestamped, TenantBound, Principal, ActiveToggle): @@ -12,12 +17,16 @@ class Service(Base, GUIDPk, Timestamped, TenantBound, Principal, ActiveToggle): __tablename__ = "services" __table_args__ = ({"schema": "authn"},) - name = Column(String(120), unique=True, nullable=False) - service_keys = relationship( + name: str = acol(storage=S(String(120), unique=True, nullable=False)) + _service_keys = relationship( "auto_authn.v2.orm.tables.ServiceKey", - back_populates="service", + back_populates="_service", cascade="all, delete-orphan", ) + service_keys: list["ServiceKey"] = vcol( + read_producer=lambda obj, _ctx: obj._service_keys, + io=IO(out_verbs=("read", "list")), + ) __all__ = ["Service"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/service_key.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/service_key.py index b98bc7d051..1a0ee98de5 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/service_key.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/service_key.py @@ -2,8 +2,19 @@ from __future__ import annotations -from autoapi.v2.tables import ApiKey as ApiKeyBase -from autoapi.v2.types import Column, ForeignKey, PgUUID, UniqueConstraint, relationship +import hashlib +import secrets + +from autoapi.v3.tables import ApiKey as ApiKeyBase +from autoapi.v3.types import PgUUID, UniqueConstraint, relationship +from autoapi.v3.specs import IO, S, acol, vcol +from autoapi.v3.specs.storage_spec import ForeignKeySpec +from autoapi.v3 import hook_ctx +from uuid import UUID +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from .service import Service class ServiceKey(ApiKeyBase): @@ -12,18 +23,46 @@ class ServiceKey(ApiKeyBase): UniqueConstraint("digest"), {"extend_existing": True, "schema": "authn"}, ) - service_id = Column( - PgUUID(as_uuid=True), - ForeignKey("authn.services.id"), - index=True, - nullable=False, + service_id: UUID = acol( + storage=S( + PgUUID(as_uuid=True), + fk=ForeignKeySpec(target="authn.services.id"), + index=True, + nullable=False, + ) ) - service = relationship( + _service = relationship( "auto_authn.v2.orm.tables.Service", - back_populates="service_keys", + back_populates="_service_keys", lazy="joined", ) + service: "Service" = vcol( + read_producer=lambda obj, _ctx: obj._service, + io=IO(out_verbs=("read", "list")), + ) + + @hook_ctx(ops="create", phase="PRE_HANDLER") + async def _generate_digest(cls, ctx): + payload = ctx.get("payload") or {} + token = secrets.token_urlsafe(32) + payload["digest"] = hashlib.sha256(token.encode()).hexdigest() + ctx["raw_key"] = token + + @hook_ctx(ops="create", phase="POST_RESPONSE") + async def _return_raw_key(cls, ctx): + raw = ctx.get("raw_key") + result = ctx.get("result") + if not raw or result is None: + return + if hasattr(result, "model_dump"): + data = result.model_dump() + elif hasattr(result, "dict") and callable(result.dict): + data = result.dict() # type: ignore[call-arg] + else: + data = dict(result) + data["raw_key"] = raw + ctx["result"] = data __all__ = ["ServiceKey"] diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/tables.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/tables.py index 7d3cf16314..a185f5a5f8 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/tables.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/tables.py @@ -2,7 +2,7 @@ from __future__ import annotations -from autoapi.v2 import Base +from autoapi.v3.tables import Base from .api_key import ApiKey from .auth_code import AuthCode diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/tenant.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/tenant.py index e4b6c6137b..3735b97fb7 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/tenant.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/tenant.py @@ -4,9 +4,10 @@ import uuid -from autoapi.v2.tables import Tenant as TenantBase -from autoapi.v2.types import Column, String -from autoapi.v2.mixins import Bootstrappable +from autoapi.v3.tables import Tenant as TenantBase +from autoapi.v3.mixins import Bootstrappable +from autoapi.v3.specs import acol, S +from autoapi.v3.types import String class Tenant(TenantBase, Bootstrappable): @@ -16,8 +17,8 @@ class Tenant(TenantBase, Bootstrappable): "schema": "authn", }, ) - name = Column(String, nullable=False, unique=True) - email = Column(String, nullable=False, unique=True) + name: str = acol(storage=S(String, nullable=False, unique=True)) + email: str = acol(storage=S(String, nullable=False, unique=True)) DEFAULT_ROWS = [ { "id": uuid.UUID("FFFFFFFF-0000-0000-0000-000000000000"), diff --git a/pkgs/standards/auto_authn/auto_authn/v2/orm/user.py b/pkgs/standards/auto_authn/auto_authn/v2/orm/user.py index 658dd33044..b448b32e18 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/orm/user.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/orm/user.py @@ -4,21 +4,44 @@ import uuid -from autoapi.v2.tables import User as UserBase -from autoapi.v2.types import Column, LargeBinary, String, relationship +from autoapi.v3.tables import User as UserBase +from autoapi.v3 import op_ctx, hook_ctx +from autoapi.v3.types import LargeBinary, String, relationship +from autoapi.v3.specs import IO, S, acol, vcol +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from .api_key import ApiKey class User(UserBase): """Human principal with authentication credentials.""" __table_args__ = ({"extend_existing": True, "schema": "authn"},) - email = Column(String(120), nullable=False, unique=True) - password_hash = Column(LargeBinary(60)) - api_keys = relationship( + email: str = acol(storage=S(String(120), nullable=False, unique=True)) + password_hash: bytes | None = acol(storage=S(LargeBinary(60))) + _api_keys = relationship( "auto_authn.v2.orm.tables.ApiKey", - back_populates="user", + back_populates="_user", cascade="all, delete-orphan", ) + api_keys: list["ApiKey"] = vcol( + read_producer=lambda obj, _ctx: obj._api_keys, + io=IO(out_verbs=("read", "list")), + ) + + @hook_ctx(ops=("create", "update"), phase="PRE_HANDLER") + async def _hash_password(cls, ctx): + payload = ctx.get("payload") or {} + plain = payload.pop("password", None) + if plain: + from ..crypto import hash_pw + + payload["password_hash"] = hash_pw(plain) + + @op_ctx(alias="register", target="create", arity="collection") + def register(cls, ctx): + pass @classmethod def new(cls, tenant_id: uuid.UUID, username: str, email: str, password: str): From 24044d84205dbbfcd322fb758cbfe852310ff094 Mon Sep 17 00:00:00 2001 From: cobycloud <25079070+cobycloud@users.noreply.github.com> Date: Tue, 26 Aug 2025 02:15:33 -0500 Subject: [PATCH 2/2] feat: route auth flows through ORM operations --- .../auto_authn/auto_authn/v2/rfc7009.py | 12 +++- .../auto_authn/auto_authn/v2/rfc8628.py | 55 ++++++++++------- .../auto_authn/auto_authn/v2/rfc9126.py | 60 +++++++++++-------- .../auto_authn/v2/routers/authn/login.py | 44 ++++++++------ .../auto_authn/v2/routers/authn/logout.py | 12 +++- .../auto_authn/v2/routers/authn/register.py | 54 ++++++++++------- .../auto_authn/v2/routers/authz/oidc.py | 15 +++-- .../auto_authn/v2/routers/authz/rfc6749.py | 58 +++++++++--------- pkgs/standards/auto_authn/tests/conftest.py | 6 +- .../auto_authn/tests/i9n/test_rfc8628.py | 4 +- 10 files changed, 192 insertions(+), 128 deletions(-) diff --git a/pkgs/standards/auto_authn/auto_authn/v2/rfc7009.py b/pkgs/standards/auto_authn/auto_authn/v2/rfc7009.py index 998d1e8166..599c2038c0 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/rfc7009.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/rfc7009.py @@ -11,9 +11,12 @@ from typing import Final, Set -from fastapi import APIRouter, FastAPI, Form, HTTPException, status +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession from .runtime_cfg import settings +from .fastapi_deps import get_async_db +from .orm.tables import RevokedToken RFC7009_SPEC_URL: Final = "https://www.rfc-editor.org/rfc/rfc7009" @@ -55,11 +58,16 @@ def reset_revocations() -> None: @router.post("/revoke") -async def revoke(token: str = Form(...), token_type_hint: str | None = Form(None)): +async def revoke( + token: str = Form(...), + token_type_hint: str | None = Form(None), + db: AsyncSession = Depends(get_async_db), +): """RFC 7009 token revocation endpoint.""" if not settings.enable_rfc7009: raise HTTPException(status.HTTP_404_NOT_FOUND, "revocation disabled") + await RevokedToken.handlers.revoke.core({"db": db, "payload": {"token": token}}) revoke_token(token) return {} diff --git a/pkgs/standards/auto_authn/auto_authn/v2/rfc8628.py b/pkgs/standards/auto_authn/auto_authn/v2/rfc8628.py index fde9cfc9c8..8ffc469ddf 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/rfc8628.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/rfc8628.py @@ -15,12 +15,15 @@ import re import secrets import string -from typing import Any, Dict, Final, Literal +from typing import Final, Literal -from fastapi import APIRouter, FastAPI, HTTPException, status +from fastapi import APIRouter, Depends, FastAPI, HTTPException, status from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession from .runtime_cfg import settings +from .fastapi_deps import get_async_db +from .orm.tables import DeviceCode # Character set for user_code per RFC 8628 §6.1 (uppercase letters and digits) _USER_CODE_CHARSET: Final = string.ascii_uppercase + string.digits @@ -31,12 +34,7 @@ RFC8628_SPEC_URL: Final = "https://www.rfc-editor.org/rfc/rfc8628" -# --------------------------------------------------------------------------- -# In-memory device authorization store -# --------------------------------------------------------------------------- router = APIRouter() - -DEVICE_CODES: Dict[str, Dict[str, Any]] = {} DEVICE_VERIFICATION_URI = "https://example.com/device" DEVICE_CODE_EXPIRES_IN = 600 # seconds DEVICE_CODE_INTERVAL = 5 # seconds @@ -69,7 +67,9 @@ class DeviceGrantForm(BaseModel): @router.post("/device_authorization", response_model=DeviceAuthOut) -async def device_authorization(body: DeviceAuthIn) -> DeviceAuthOut: +async def device_authorization( + body: DeviceAuthIn, db: AsyncSession = Depends(get_async_db) +) -> DeviceAuthOut: """Issue a new device and user code pair.""" if not settings.enable_rfc8628: @@ -80,15 +80,18 @@ async def device_authorization(body: DeviceAuthIn) -> DeviceAuthOut: verification_uri = DEVICE_VERIFICATION_URI verification_uri_complete = f"{verification_uri}?user_code={user_code}" expires_at = datetime.utcnow() + timedelta(seconds=DEVICE_CODE_EXPIRES_IN) - DEVICE_CODES[device_code] = { - "user_code": user_code, - "client_id": body.client_id, - "expires_at": expires_at, - "interval": DEVICE_CODE_INTERVAL, - "authorized": False, - "sub": None, - "tid": None, - } + await DeviceCode.handlers.device_authorization.core( + { + "db": db, + "payload": { + "device_code": device_code, + "user_code": user_code, + "client_id": body.client_id, + "expires_at": expires_at, + "interval": DEVICE_CODE_INTERVAL, + }, + } + ) return DeviceAuthOut( device_code=device_code, user_code=user_code, @@ -99,13 +102,20 @@ async def device_authorization(body: DeviceAuthIn) -> DeviceAuthOut: ) -def approve_device_code(device_code: str, sub: str, tid: str) -> None: +async def approve_device_code( + device_code: str, sub: str, tid: str, db: AsyncSession +) -> None: """Mark a device code as authorized (testing helper).""" - if device_code in DEVICE_CODES: - DEVICE_CODES[device_code]["authorized"] = True - DEVICE_CODES[device_code]["sub"] = sub - DEVICE_CODES[device_code]["tid"] = tid + obj = await DeviceCode.handlers.read.core({"db": db, "obj_id": device_code}) + if obj: + await DeviceCode.handlers.update.core( + { + "db": db, + "obj": obj, + "payload": {"authorized": True, "user_id": sub, "tenant_id": tid}, + } + ) def include_rfc8628(app: FastAPI) -> None: @@ -157,7 +167,6 @@ def generate_device_code() -> str: "DeviceAuthIn", "DeviceAuthOut", "DeviceGrantForm", - "DEVICE_CODES", "approve_device_code", "include_rfc8628", "router", diff --git a/pkgs/standards/auto_authn/auto_authn/v2/rfc9126.py b/pkgs/standards/auto_authn/auto_authn/v2/rfc9126.py index 7064ae6e61..17c49f7ce8 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/rfc9126.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/rfc9126.py @@ -12,14 +12,15 @@ import uuid from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Final, Tuple +from typing import Any, Dict, Final -from fastapi import APIRouter, FastAPI, HTTPException, Request, status +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession from . import runtime_cfg - -# In-memory storage mapping request_uri -> (params, expiry) -_PAR_STORE: Dict[str, Tuple[Dict[str, Any], datetime]] = {} +from .fastapi_deps import get_async_db +from .orm.tables import PushedAuthorizationRequest router = APIRouter() @@ -28,46 +29,57 @@ RFC9126_SPEC_URL: Final = "https://www.rfc-editor.org/rfc/rfc9126" -def store_par_request( - params: Dict[str, Any], expires_in: int = DEFAULT_PAR_EXPIRY +async def store_par_request( + params: Dict[str, Any], + db: AsyncSession, + expires_in: int = DEFAULT_PAR_EXPIRY, ) -> str: - """Store *params* and return a unique ``request_uri``. + """Store *params* and return a unique ``request_uri``.""" - Parameters expire after *expires_in* seconds. - """ request_uri = f"urn:ietf:params:oauth:request_uri:{uuid.uuid4()}" - _PAR_STORE[request_uri] = ( - params, - datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in), + expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in) + await PushedAuthorizationRequest.handlers.par.core( + { + "db": db, + "payload": { + "request_uri": request_uri, + "params": params, + "expires_at": expires_at, + }, + } ) return request_uri -def get_par_request(request_uri: str) -> Dict[str, Any] | None: +async def get_par_request(request_uri: str, db: AsyncSession) -> Dict[str, Any] | None: """Retrieve parameters for *request_uri* if present and not expired.""" - record = _PAR_STORE.get(request_uri) - if not record: + + obj = await db.get(PushedAuthorizationRequest, request_uri) + if not obj: return None - params, expiry = record - if datetime.now(tz=timezone.utc) > expiry: - del _PAR_STORE[request_uri] + if datetime.now(tz=timezone.utc) > obj.expires_at: + await PushedAuthorizationRequest.handlers.delete.core({"db": db, "obj": obj}) return None - return params + return obj.params -def reset_par_store() -> None: +async def reset_par_store(db: AsyncSession) -> None: """Clear stored pushed authorization requests (test helper).""" - _PAR_STORE.clear() + + await db.execute(delete(PushedAuthorizationRequest)) + await db.commit() @router.post("/par", status_code=status.HTTP_201_CREATED) -async def pushed_authorization_request(request: Request): +async def pushed_authorization_request( + request: Request, db: AsyncSession = Depends(get_async_db) +): """Endpoint for RFC 9126 pushed authorization requests.""" if not runtime_cfg.settings.enable_rfc9126: raise HTTPException(status.HTTP_404_NOT_FOUND, "PAR disabled") form = await request.form() - request_uri = store_par_request(dict(form)) + request_uri = await store_par_request(dict(form), db) return {"request_uri": request_uri, "expires_in": DEFAULT_PAR_EXPIRY} diff --git a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/login.py b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/login.py index 5fb5ed8a02..29cbf1f22b 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/login.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/login.py @@ -1,19 +1,18 @@ from __future__ import annotations -from datetime import datetime import secrets -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, Request from fastapi.responses import JSONResponse from sqlalchemy.ext.asyncio import AsyncSession from ...oidc_id_token import mint_id_token from ...rfc8414_metadata import ISSUER from ...fastapi_deps import get_async_db -from ...backends import AuthError +from ...orm.tables import AuthSession from ..schemas import CredsIn, TokenPair -from ..shared import _require_tls, _jwt, _pwd_backend, SESSIONS +from ..shared import _require_tls, _jwt, SESSIONS from . import router @@ -23,29 +22,40 @@ async def login( body: CredsIn, request: Request, db: AsyncSession = Depends(get_async_db) ): _require_tls(request) + session_id = secrets.token_urlsafe(16) try: - user = await _pwd_backend.authenticate(db, body.identifier, body.password) - except AuthError: - raise HTTPException(status.HTTP_404_NOT_FOUND, "invalid credentials") + session = await AuthSession.handlers.login.core( + { + "db": db, + "payload": { + "id": session_id, + "username": body.identifier, + "password": body.password, + }, + } + ) + except HTTPException: + raise access, refresh = await _jwt.async_sign_pair( - sub=str(user.id), tid=str(user.tenant_id), scope="openid profile email" + sub=str(session.user_id), + tid=str(session.tenant_id), + scope="openid profile email", ) - session_id = secrets.token_urlsafe(16) - SESSIONS[session_id] = { - "sub": str(user.id), - "tid": str(user.tenant_id), - "username": user.username, - "auth_time": datetime.utcnow(), + SESSIONS[session.id] = { + "sub": str(session.user_id), + "tid": str(session.tenant_id), + "username": session.username, + "auth_time": session.auth_time, } id_token = mint_id_token( - sub=str(user.id), + sub=str(session.user_id), aud=ISSUER, nonce=secrets.token_urlsafe(8), issuer=ISSUER, - sid=session_id, + sid=session.id, ) pair = TokenPair(access_token=access, refresh_token=refresh, id_token=id_token) response = JSONResponse(pair.model_dump()) - response.set_cookie("sid", session_id, httponly=True, samesite="lax") + response.set_cookie("sid", session.id, httponly=True, samesite="lax") return response diff --git a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/logout.py b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/logout.py index e2203b27e3..7df446bf89 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/logout.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/logout.py @@ -1,9 +1,12 @@ from __future__ import annotations -from fastapi import HTTPException, Request, Response, status +from fastapi import Depends, HTTPException, Request, Response, status +from sqlalchemy.ext.asyncio import AsyncSession from ...oidc_id_token import verify_id_token from ...rfc8414_metadata import ISSUER +from ...fastapi_deps import get_async_db +from ...orm.tables import AuthSession from ..schemas import LogoutIn from ..shared import _require_tls, _front_channel_logout, _back_channel_logout, SESSIONS @@ -12,7 +15,9 @@ @router.post("/logout", status_code=status.HTTP_204_NO_CONTENT) -async def logout(body: LogoutIn, request: Request): +async def logout( + body: LogoutIn, request: Request, db: AsyncSession = Depends(get_async_db) +): _require_tls(request) try: claims = verify_id_token(body.id_token_hint, issuer=ISSUER, audience=ISSUER) @@ -22,6 +27,9 @@ async def logout(body: LogoutIn, request: Request): ) from exc sid = claims.get("sid") if sid: + session = await AuthSession.handlers.read.core({"db": db, "obj_id": sid}) + if session: + await AuthSession.handlers.logout.core({"db": db, "obj": session}) SESSIONS.pop(sid, None) await _front_channel_logout(sid) await _back_channel_logout(sid) diff --git a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/register.py b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/register.py index 8044030aec..b48f1114c9 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/register.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/routers/authn/register.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import datetime import secrets from fastapi import Depends, HTTPException, Request, status @@ -8,10 +7,9 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ...crypto import hash_pw from ...oidc_id_token import mint_id_token from ...rfc8414_metadata import ISSUER -from ...orm.tables import Tenant, User +from ...orm.tables import AuthSession, Tenant, User from ...fastapi_deps import get_async_db from ..schemas import RegisterIn, TokenPair @@ -41,16 +39,29 @@ async def register( ) if tenant is None: raise HTTPException(status.HTTP_404_NOT_FOUND, "tenant not found") - user = User( - tenant_id=tenant.id, - username=body.username, - email=body.email, - password_hash=hash_pw(body.password), + await User.handlers.register.core( + { + "db": db, + "payload": { + "tenant_id": tenant.id, + "username": body.username, + "email": body.email, + "password": body.password, + }, + } + ) + session_id = secrets.token_urlsafe(16) + session = await AuthSession.handlers.login.core( + { + "db": db, + "payload": { + "id": session_id, + "username": body.username, + "password": body.password, + }, + } ) - db.add(user) - await db.commit() except Exception as exc: - await db.rollback() if isinstance(exc, HTTPException): raise from autoapi.v2.error import IntegrityError @@ -62,23 +73,24 @@ async def register( ) from exc access, refresh = await _jwt.async_sign_pair( - sub=str(user.id), tid=str(tenant.id), scope="openid profile email" + sub=str(session.user_id), + tid=str(session.tenant_id), + scope="openid profile email", ) - session_id = secrets.token_urlsafe(16) - SESSIONS[session_id] = { - "sub": str(user.id), - "tid": str(tenant.id), - "username": user.username, - "auth_time": datetime.utcnow(), + SESSIONS[session.id] = { + "sub": str(session.user_id), + "tid": str(session.tenant_id), + "username": session.username, + "auth_time": session.auth_time, } id_token = mint_id_token( - sub=str(user.id), + sub=str(session.user_id), aud=ISSUER, nonce=secrets.token_urlsafe(8), issuer=ISSUER, - sid=session_id, + sid=session.id, ) pair = TokenPair(access_token=access, refresh_token=refresh, id_token=id_token) response = JSONResponse(pair.model_dump()) - response.set_cookie("sid", session_id, httponly=True, samesite="lax") + response.set_cookie("sid", session.id, httponly=True, samesite="lax") return response diff --git a/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/oidc.py b/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/oidc.py index a20843acd9..d583e2c3c9 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/oidc.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/oidc.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ...fastapi_deps import get_async_db -from ...orm.tables import Client, User +from ...orm.tables import AuthCode, Client, User from ...oidc_id_token import mint_id_token, oidc_hash from ...rfc8414_metadata import ISSUER from ...rfc8252 import is_native_redirect_uri @@ -108,10 +108,11 @@ async def authorize( ) if "code" in rts: code = secrets.token_urlsafe(32) - AUTH_CODES[code] = { - "sub": user_sub, - "tid": tenant_id, - "client_id": client_id, + payload = { + "code": code, + "user_id": UUID(user_sub), + "tenant_id": UUID(tenant_id), + "client_id": UUID(client_id), "redirect_uri": redirect_uri, "code_challenge": code_challenge, "nonce": nonce, @@ -119,7 +120,9 @@ async def authorize( "expires_at": datetime.utcnow() + timedelta(minutes=10), } if requested_claims: - AUTH_CODES[code]["claims"] = requested_claims + payload["claims"] = requested_claims + await AuthCode.handlers.create.core({"db": db, "payload": payload}) + AUTH_CODES[code] = payload params.append(("code", code)) if "token" in rts: from ..shared import _jwt diff --git a/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/rfc6749.py b/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/rfc6749.py index af989348e0..d28ea9bc1e 100644 --- a/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/rfc6749.py +++ b/pkgs/standards/auto_authn/auto_authn/v2/routers/authz/rfc6749.py @@ -1,7 +1,6 @@ from __future__ import annotations import secrets -from uuid import UUID from datetime import datetime from typing import Any @@ -14,7 +13,7 @@ from ...backends import AuthError from ...fastapi_deps import get_async_db -from ...orm.tables import Client, User +from ...orm.tables import AuthCode, Client, DeviceCode, User from ...rfc8707 import extract_resource from ...runtime_cfg import settings from ...rfc6749 import ( @@ -25,7 +24,7 @@ is_enabled as rfc6749_enabled, ) from ...rfc7636_pkce import verify_code_challenge -from ...rfc8628 import DEVICE_CODES, DeviceGrantForm +from ...rfc8628 import DeviceGrantForm from ...oidc_id_token import mint_id_token, oidc_hash from ...rfc8414_metadata import ISSUER @@ -35,7 +34,7 @@ RefreshIn, TokenPair, ) -from ..shared import _require_tls, _jwt, _pwd_backend, _ALLOWED_GRANT_TYPES, AUTH_CODES +from ..shared import _require_tls, _jwt, _pwd_backend, _ALLOWED_GRANT_TYPES from . import router @@ -134,72 +133,75 @@ async def token( parsed = AuthorizationCodeGrantForm(**data) except ValidationError as exc: raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, exc.errors()) - record = AUTH_CODES.pop(parsed.code, None) + auth_code = await AuthCode.handlers.read.core({"db": db, "obj_id": parsed.code}) if ( - record is None - or record["client_id"] != parsed.client_id - or record["redirect_uri"] != parsed.redirect_uri - or datetime.utcnow() > record["expires_at"] + auth_code is None + or str(auth_code.client_id) != parsed.client_id + or auth_code.redirect_uri != parsed.redirect_uri + or datetime.utcnow() > auth_code.expires_at ): return JSONResponse( {"error": "invalid_grant"}, status_code=status.HTTP_400_BAD_REQUEST ) - if record.get("code_challenge"): + if auth_code.code_challenge: if not parsed.code_verifier or not verify_code_challenge( - parsed.code_verifier, record["code_challenge"] + parsed.code_verifier, auth_code.code_challenge ): return JSONResponse( {"error": "invalid_grant"}, status_code=status.HTTP_400_BAD_REQUEST ) jwt_kwargs = {"aud": aud} if aud else {} - if record.get("scope"): - jwt_kwargs["scope"] = record["scope"] + if auth_code.scope: + jwt_kwargs["scope"] = auth_code.scope access, refresh = await _jwt.async_sign_pair( - sub=record["sub"], tid=record["tid"], **jwt_kwargs + sub=str(auth_code.user_id), tid=str(auth_code.tenant_id), **jwt_kwargs ) - nonce = record.get("nonce") or secrets.token_urlsafe(8) + nonce = auth_code.nonce or secrets.token_urlsafe(8) extra_claims: dict[str, Any] = { - "tid": record["tid"], + "tid": str(auth_code.tenant_id), "typ": "id", "at_hash": oidc_hash(access), } - if record.get("claims") and "id_token" in record["claims"]: - user_obj = await db.get(User, UUID(record["sub"])) - idc = record["claims"]["id_token"] + if auth_code.claims and "id_token" in auth_code.claims: + user_obj = await db.get(User, auth_code.user_id) + idc = auth_code.claims["id_token"] if "email" in idc: extra_claims["email"] = user_obj.email if user_obj else "" if any(k in idc for k in ("name", "preferred_username")): extra_claims["name"] = user_obj.username if user_obj else "" id_token = mint_id_token( - sub=record["sub"], + sub=str(auth_code.user_id), aud=parsed.client_id, nonce=nonce, issuer=ISSUER, **extra_claims, ) + await AuthCode.handlers.exchange.core({"db": db, "obj": auth_code}) return TokenPair(access_token=access, refresh_token=refresh, id_token=id_token) if grant_type == "urn:ietf:params:oauth:grant-type:device_code": try: parsed = DeviceGrantForm(**data) except ValidationError as exc: raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, exc.errors()) - record = DEVICE_CODES.get(parsed.device_code) - if not record or record["client_id"] != parsed.client_id: + device_obj = await DeviceCode.handlers.read.core( + {"db": db, "obj_id": parsed.device_code} + ) + if not device_obj or str(device_obj.client_id) != parsed.client_id: raise HTTPException(status.HTTP_400_BAD_REQUEST, {"error": "invalid_grant"}) - if datetime.utcnow() > record["expires_at"]: - DEVICE_CODES.pop(parsed.device_code, None) + if datetime.utcnow() > device_obj.expires_at: + await DeviceCode.handlers.delete.core({"db": db, "obj": device_obj}) raise HTTPException(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) - if not record.get("authorized"): + if not device_obj.authorized: raise HTTPException( status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"} ) jwt_kwargs = {"aud": aud} if aud else {} access, refresh = await _jwt.async_sign_pair( - sub=record.get("sub", "device-user"), - tid=record.get("tid", "device-tenant"), + sub=str(device_obj.user_id or "device-user"), + tid=str(device_obj.tenant_id or "device-tenant"), **jwt_kwargs, ) - DEVICE_CODES.pop(parsed.device_code, None) + await DeviceCode.handlers.delete.core({"db": db, "obj": device_obj}) return TokenPair(access_token=access, refresh_token=refresh) if rfc6749_enabled(): return JSONResponse( diff --git a/pkgs/standards/auto_authn/tests/conftest.py b/pkgs/standards/auto_authn/tests/conftest.py index 4921e170bd..ee01faa185 100644 --- a/pkgs/standards/auto_authn/tests/conftest.py +++ b/pkgs/standards/auto_authn/tests/conftest.py @@ -169,20 +169,20 @@ def enable_rfc8414(): @pytest.fixture -def enable_rfc9126(): +def enable_rfc9126(db_session): """Enable RFC 9126 pushed authorization requests for tests.""" from auto_authn.v2.runtime_cfg import settings from auto_authn.v2.rfc9126 import reset_par_store, include_rfc9126 original = settings.enable_rfc9126 settings.enable_rfc9126 = True - reset_par_store() + asyncio.get_event_loop().run_until_complete(reset_par_store(db_session)) include_rfc9126(app) try: yield finally: settings.enable_rfc9126 = original - reset_par_store() + asyncio.get_event_loop().run_until_complete(reset_par_store(db_session)) @pytest.fixture diff --git a/pkgs/standards/auto_authn/tests/i9n/test_rfc8628.py b/pkgs/standards/auto_authn/tests/i9n/test_rfc8628.py index f51661b1aa..4f81926b5c 100644 --- a/pkgs/standards/auto_authn/tests/i9n/test_rfc8628.py +++ b/pkgs/standards/auto_authn/tests/i9n/test_rfc8628.py @@ -33,7 +33,7 @@ async def test_device_authorization_endpoint(async_client: AsyncClient) -> None: @pytest.mark.integration @pytest.mark.asyncio -async def test_device_token_polling(async_client: AsyncClient) -> None: +async def test_device_token_polling(async_client: AsyncClient, db_session) -> None: """Token endpoint should poll until the device code is approved.""" auth_resp = await async_client.post( "/device_authorization", data={"client_id": "test-client"} @@ -50,7 +50,7 @@ async def test_device_token_polling(async_client: AsyncClient) -> None: from auto_authn.v2.rfc8628 import approve_device_code - approve_device_code(device_code, sub="user", tid="tenant") + await approve_device_code(device_code, sub="user", tid="tenant", db=db_session) success = await async_client.post("/token", data=payload) assert success.status_code == status.HTTP_200_OK data = success.json()