Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pkgs/standards/autoapi/autoapi/v3/mixins/_RowBound.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# autoapi/v2/mixins/bound.py
# autoapi/v3/mixins/_RowBound.py
from __future__ import annotations

from typing import Any, Mapping, Sequence

from autoapi.v2.hooks import Phase
from autoapi.v2.types import HookProvider
from autoapi.v2.jsonrpc_models import HTTP_ERROR_MESSAGES, create_standardized_error
from ..types import HookProvider
from ..impl.runtime.errors import (
HTTP_ERROR_MESSAGES,
create_standardized_error_from_status,
)


class _RowBound(HookProvider):
Expand All @@ -32,7 +34,7 @@ def __autoapi_register_hooks__(cls, api) -> None:
return

for op in ("read", "list"):
api.register_hook(model=cls, phase=Phase.POST_HANDLER, op=op)(
api.register_hook(model=cls, phase="POST_HANDLER", op=op)(
cls._make_row_visibility_hook()
)

Expand All @@ -54,7 +56,7 @@ def _row_visibility_hook(ctx: Mapping[str, Any]) -> None:

# READ → invisible row → pretend 404
if not cls.is_visible(res, ctx):
http_exc, _, _ = create_standardized_error(
http_exc, _, _ = create_standardized_error_from_status(
404, message=HTTP_ERROR_MESSAGES[404]
)
raise http_exc
Expand Down
8 changes: 3 additions & 5 deletions pkgs/standards/autoapi/autoapi/v3/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
UUID,
uuid4,
)
from ..cfgs import AUTH_CONTEXT_KEY, USER_ID_KEY
from ..config.constants import CTX_USER_ID_KEY


def tzutcnow() -> dt.datetime: # default/on‑update factory
Expand Down Expand Up @@ -155,8 +155,7 @@ class OwnerBound:

@classmethod
def filter_for_ctx(cls, q, ctx):
auto_fields = ctx.get(AUTH_CONTEXT_KEY, {})
return q.filter(cls.owner_id == auto_fields.get(USER_ID_KEY))
return q.filter(cls.owner_id == ctx.get(CTX_USER_ID_KEY))


class UserBound: # membership rows
Expand All @@ -168,8 +167,7 @@ class UserBound: # membership rows

@classmethod
def filter_for_ctx(cls, q, ctx):
auto_fields = ctx.get(AUTH_CONTEXT_KEY, {})
return q.filter(cls.user_id == auto_fields.get(USER_ID_KEY))
return q.filter(cls.user_id == ctx.get(CTX_USER_ID_KEY))


# ────────── lifecycle --------------------------------------------------
Expand Down
20 changes: 8 additions & 12 deletions pkgs/standards/autoapi/autoapi/v3/mixins/ownable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import logging
from uuid import UUID

from ..hooks import Phase
from ..jsonrpc_models import create_standardized_error
from ..info_schema import check as _info_check
from ..types import Column, ForeignKey, PgUUID, declared_attr
from ..cfgs import AUTH_CONTEXT_KEY, INJECTED_FIELDS_KEY, USER_ID_KEY
from ..config.constants import CTX_USER_ID_KEY
from ..impl.runtime.errors import create_standardized_error_from_status

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,7 +72,7 @@ def __autoapi_register_hooks__(cls, api):
pol = cls.__autoapi_owner_policy__

def _err(status: int, msg: str):
http_exc, _, _ = create_standardized_error(status, message=msg)
http_exc, _, _ = create_standardized_error_from_status(status, message=msg)
raise http_exc

def _ownable_before_create(ctx):
Expand All @@ -82,16 +81,14 @@ def _ownable_before_create(ctx):
# keep None so we can treat it as "missing" explicitly
params = params.model_dump()

auto_fields = ctx.get(AUTH_CONTEXT_KEY, {})
user_id = auto_fields.get(USER_ID_KEY)
user_id = ctx.get(CTX_USER_ID_KEY)
provided = params.get("owner_id")
missing = _is_missing(provided)

log.info(
"Ownable before_create policy=%s params=%s auto_fields=%s",
"Ownable before_create policy=%s params=%s",
pol,
params,
auto_fields,
)

if pol == OwnerPolicy.STRICT_SERVER:
Expand Down Expand Up @@ -126,8 +123,7 @@ def _ownable_before_update(ctx, obj):
_err(400, "owner_id is immutable.")

new_val = _normalize_uuid(params["owner_id"])
auto_fields = ctx.get(INJECTED_FIELDS_KEY, {})
user_id = _normalize_uuid(auto_fields.get(USER_ID_KEY))
user_id = _normalize_uuid(ctx.get(CTX_USER_ID_KEY))

log.info(
"Ownable before_update new_val=%s obj_owner=%s injected=%s",
Expand All @@ -142,9 +138,9 @@ def _ownable_before_update(ctx, obj):
):
_err(403, "Cannot transfer ownership.")

api.register_hook(model=cls, phase=Phase.PRE_TX_BEGIN, op="create")(
api.register_hook(model=cls, phase="PRE_TX_BEGIN", op="create")(
_ownable_before_create
)
api.register_hook(model=cls, phase=Phase.PRE_TX_BEGIN, op="update")(
api.register_hook(model=cls, phase="PRE_TX_BEGIN", op="update")(
_ownable_before_update
)
21 changes: 8 additions & 13 deletions pkgs/standards/autoapi/autoapi/v3/mixins/tenant_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from ._RowBound import _RowBound
from ..types import Column, ForeignKey, PgUUID, declared_attr
from ..hooks import Phase
from ..jsonrpc_models import create_standardized_error
from ..info_schema import check as _info_check
from ..cfgs import AUTH_CONTEXT_KEY, INJECTED_FIELDS_KEY, TENANT_ID_KEY
from ..config.constants import CTX_TENANT_ID_KEY
from ..impl.runtime.errors import create_standardized_error_from_status

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,8 +88,7 @@ def __tablename__(cls):
# -------------------------------------------------------------------
@staticmethod
def is_visible(obj, ctx) -> bool:
auto_fields = ctx.get(AUTH_CONTEXT_KEY, {})
ctx_tenant_id = auto_fields.get(TENANT_ID_KEY)
ctx_tenant_id = ctx.get(CTX_TENANT_ID_KEY)
return getattr(obj, "tenant_id", None) == ctx_tenant_id

# -------------------------------------------------------------------
Expand All @@ -101,7 +99,7 @@ def __autoapi_register_hooks__(cls, api):
pol = cls.__autoapi_tenant_policy__

def _err(code: int, msg: str):
http_exc, _, _ = create_standardized_error(code, message=msg)
http_exc, _, _ = create_standardized_error_from_status(code, message=msg)
raise http_exc

# INSERT
Expand All @@ -112,9 +110,7 @@ def _tenantbound_before_create(ctx):
# and rely on _is_missing() to decide.
params = params.model_dump()

auto_fields = ctx.get(INJECTED_FIELDS_KEY, {})
print(f"\n🚧{auto_fields}")
injected_tid = auto_fields.get(TENANT_ID_KEY)
injected_tid = ctx.get(CTX_TENANT_ID_KEY)
print(f"\n🚧🚧{injected_tid}")
provided = params.get("tenant_id")
missing = _is_missing(provided)
Expand Down Expand Up @@ -155,8 +151,7 @@ def _tenantbound_before_update(ctx, obj):
_err(400, "tenant_id is immutable.")

new_val = _normalize_uuid(provided)
auto_fields = ctx.get(INJECTED_FIELDS_KEY, {})
injected_tid = _normalize_uuid(auto_fields.get(TENANT_ID_KEY))
injected_tid = _normalize_uuid(ctx.get(CTX_TENANT_ID_KEY))

log.info(
"TenantBound before_update new_val=%s obj_tid=%s injected=%s",
Expand All @@ -173,9 +168,9 @@ def _tenantbound_before_update(ctx, obj):
_err(403, "Cannot switch tenant context.")

# Register hooks
api.register_hook(model=cls, phase=Phase.PRE_TX_BEGIN, op="create")(
api.register_hook(model=cls, phase="PRE_TX_BEGIN", op="create")(
_tenantbound_before_create
)
api.register_hook(model=cls, phase=Phase.PRE_TX_BEGIN, op="update")(
api.register_hook(model=cls, phase="PRE_TX_BEGIN", op="update")(
_tenantbound_before_update
)
19 changes: 14 additions & 5 deletions pkgs/standards/autoapi/autoapi/v3/mixins/upsertable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from __future__ import annotations
from typing import Any, Mapping, Sequence, Optional, Tuple
from sqlalchemy import and_, inspect as sa_inspect
from autoapi.v2.hooks import Phase
from autoapi.v2.types import Session, HookProvider
from ..types import Session, HookProvider


class Upsertable(HookProvider):
"""
Expand All @@ -12,13 +12,14 @@ class Upsertable(HookProvider):
• Else if all PK parts are present -> decide by PK
• Else -> no rewrite
"""

__upsert_keys__: Sequence[str] | None = None # optional natural key list

@classmethod
def __autoapi_register_hooks__(cls, api) -> None:
model = cls.__tablename__
for op in ("create", "update", "replace"):
api.register_hook(Phase.PRE_TX_BEGIN, model=model, op=op)(
api.register_hook("PRE_TX_BEGIN", model=model, op=op)(
cls._make_upsert_rewrite_hook(op)
)

Expand Down Expand Up @@ -57,7 +58,10 @@ async def _rewrite(ctx: Mapping[str, Any]) -> None:

return _rewrite

def _extract_values(p: Mapping[str, Any], names: Sequence[str]) -> Optional[Tuple[Any, ...]]:

def _extract_values(
p: Mapping[str, Any], names: Sequence[str]
) -> Optional[Tuple[Any, ...]]:
vals = []
for n in names:
v = p.get(n)
Expand All @@ -66,19 +70,24 @@ def _extract_values(p: Mapping[str, Any], names: Sequence[str]) -> Optional[Tupl
vals.append(v)
return tuple(vals)

def _exists_by_names(model, db: Session, names: Sequence[str], vals: Tuple[Any, ...]) -> bool:

def _exists_by_names(
model, db: Session, names: Sequence[str], vals: Tuple[Any, ...]
) -> bool:
q = db.query(model)
for n, v in zip(names, vals):
q = q.filter(getattr(model, n) == v)
return db.query(q.exists()).scalar() is True


def _exists_by_pk(model, db: Session, pk_cols, pk_vals: Tuple[Any, ...]) -> bool:
if len(pk_cols) == 1:
# fast path
return db.get(model, pk_vals[0]) is not None
conds = [getattr(model, c.key) == v for c, v in zip(pk_cols, pk_vals)]
return db.query(db.query(model).filter(and_(*conds)).exists()).scalar() is True


def _rewrite_by_existence(ctx, tab: str, verb: str, exists: bool) -> None:
if verb == "create" and exists:
ctx["env"].method = f"{tab}.update"
Expand Down
Loading