diff --git a/src/app/db/models/user_role.py b/src/app/db/models/user_role.py index 44855bae..63f4b2e0 100644 --- a/src/app/db/models/user_role.py +++ b/src/app/db/models/user_role.py @@ -5,7 +5,7 @@ from uuid import UUID # noqa: TCH003 from advanced_alchemy.base import UUIDAuditBase -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -18,7 +18,10 @@ class UserRole(UUIDAuditBase): """User Role.""" __tablename__ = "user_account_role" - __table_args__ = {"comment": "Links a user to a specific role."} + __table_args__ = ( + UniqueConstraint("user_id", "role_id"), # Avoid multiple assignments of the same role + {"comment": "Links a user to a specific role."}, + ) user_id: Mapped[UUID] = mapped_column(ForeignKey("user_account.id", ondelete="cascade"), nullable=False) role_id: Mapped[UUID] = mapped_column(ForeignKey("role.id", ondelete="cascade"), nullable=False) assigned_at: Mapped[datetime] = mapped_column(default=datetime.now(UTC)) diff --git a/src/app/domain/accounts/controllers/user_role.py b/src/app/domain/accounts/controllers/user_role.py index fb794318..5e131667 100644 --- a/src/app/domain/accounts/controllers/user_role.py +++ b/src/app/domain/accounts/controllers/user_role.py @@ -1,4 +1,5 @@ """User Routes.""" + from __future__ import annotations from litestar import Controller, post @@ -44,8 +45,11 @@ async def assign_role( """Create a new migration role.""" role_id = (await roles_service.get_one(slug=role_slug)).id user_obj = await users_service.get_one(email=data.user_name) - if all(user_role.role_id != role_id for user_role in user_obj.roles): - obj, created = await user_roles_service.get_or_upsert(role_id=role_id, user_id=user_obj.id) + # if all(user_role.role_id != role_id for user_role in user_obj.roles): + obj, created = await user_roles_service.get_or_upsert( + role_id=role_id, + user_id=user_obj.id, + ) if created: return Message(message=f"Successfully assigned the '{obj.role_slug}' role to {obj.user_email}.") return Message(message=f"User {obj.user_email} already has the '{obj.role_slug}' role.") diff --git a/tests/integration/test_account_role.py b/tests/integration/test_account_role.py index 1b1fcf54..e0f5efa3 100644 --- a/tests/integration/test_account_role.py +++ b/tests/integration/test_account_role.py @@ -1,11 +1,12 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING import pytest if TYPE_CHECKING: - from httpx import AsyncClient + from httpx import AsyncClient, Response pytestmark = pytest.mark.anyio @@ -66,3 +67,28 @@ async def test_superuser_role_access( response = await client.get("/api/teams", headers=user_token_headers) assert response.status_code == 200 assert int(response.json()["total"]) == 0 + + +@pytest.mark.parametrize("n_requests", [1, 4]) +async def test_assign_role_concurrent( + client: "AsyncClient", + superuser_token_headers: dict[str, str], + n_requests: int, +) -> None: + async def post() -> Response: + return await client.post( + "/api/roles/superuser/assign", + json={"userName": "user@example.com"}, + headers=superuser_token_headers, + ) + + responses = await asyncio.gather(*[post() for _ in range(n_requests)]) + + assert all(res.status_code == 201 for res in responses) + messages = [res.json()["message"] for res in responses] + assert "Successfully assigned the 'superuser' role to user@example.com." in messages + + responses = await asyncio.gather(*[post() for _ in range(n_requests)]) + assert all(res.status_code == 201 for res in responses) + messages = [res.json()["message"] for res in responses] + assert all(msg == "User user@example.com already has the 'superuser' role." for msg in messages)