diff --git a/datajunction-query/alembic/env.py b/datajunction-query/alembic/env.py index ad293892e..ba9b881db 100644 --- a/datajunction-query/alembic/env.py +++ b/datajunction-query/alembic/env.py @@ -67,7 +67,10 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = create_engine(settings.index) + connectable = create_engine( + settings.index, + connect_args={"options": f"-csearch_path={settings.customSchema}"}, + ) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/datajunction-query/djqs/config.py b/datajunction-query/djqs/config.py index e68685d06..cc12ac999 100644 --- a/datajunction-query/djqs/config.py +++ b/datajunction-query/djqs/config.py @@ -30,6 +30,8 @@ class Settings(BaseSettings): # pylint: disable=too-few-public-methods # SQLAlchemy URI for the metadata database. index: str = "sqlite:///djqs.db?check_same_thread=False" + customSchema: str = "public" + # The default engine to use for reflection default_reflection_engine: str = "default" diff --git a/datajunction-query/djqs/utils.py b/datajunction-query/djqs/utils.py index 69a219730..dc5b5086f 100644 --- a/datajunction-query/djqs/utils.py +++ b/datajunction-query/djqs/utils.py @@ -7,13 +7,14 @@ import logging import os from functools import lru_cache -from typing import Iterator +from typing import Iterator, Optional from dotenv import load_dotenv from pydantic.datetime_parse import parse_datetime from rich.logging import RichHandler from sqlalchemy.engine import Engine from sqlmodel import Session, create_engine +from starlette.requests import Request from djqs.config import Settings @@ -56,11 +57,18 @@ def get_metadata_engine() -> Engine: return engine -def get_session() -> Iterator[Session]: +def get_session(request: Request = None) -> Iterator[Session]: """ Per-request session. """ + schema = request.headers.get("tenant") engine = get_metadata_engine() + settings = get_settings() + + if schema: + engine = engine.execution_options(schema_translate_map={None: schema}) + + settings.customSchema = request.headers.get("new_tenant") with Session(engine, autoflush=False) as session: # pragma: no cover yield session diff --git a/datajunction-server/alembic/env.py b/datajunction-server/alembic/env.py index 50ffddfb6..c23bba9b8 100644 --- a/datajunction-server/alembic/env.py +++ b/datajunction-server/alembic/env.py @@ -85,7 +85,12 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = create_engine(settings.index) + settings = get_settings() + + connectable = create_engine( + settings.index, + connect_args={"options": f"-csearch_path={settings.customSchema}"}, + ) with connectable.connect() as connection: context.configure( diff --git a/datajunction-server/datajunction_server/api/access/authentication/basic.py b/datajunction-server/datajunction_server/api/access/authentication/basic.py index 1596b3de8..1b1a32140 100644 --- a/datajunction-server/datajunction_server/api/access/authentication/basic.py +++ b/datajunction-server/datajunction_server/api/access/authentication/basic.py @@ -4,7 +4,7 @@ from datetime import timedelta from http import HTTPStatus -from fastapi import APIRouter, Depends, Form +from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, Response from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy import select @@ -25,14 +25,19 @@ @router.post("/basic/user/") async def create_a_user( - email: str = Form(), - username: str = Form(), - password: str = Form(), + request: Request, session: AsyncSession = Depends(get_session), ) -> JSONResponse: """ Create a new user """ + body = await request.body() + if not body: + return JSONResponse(content={"error": "Request body is empty"}, status_code=400) + data = await request.json() + username = data.get("username") + email = data.get("email") + password = data.get("password") user_result = await session.execute(select(User).where(User.username == username)) if user_result.scalar_one_or_none(): raise DJException( diff --git a/datajunction-server/datajunction_server/config.py b/datajunction-server/datajunction_server/config.py index 69425f400..659237c1d 100644 --- a/datajunction-server/datajunction_server/config.py +++ b/datajunction-server/datajunction_server/config.py @@ -33,6 +33,8 @@ class Settings( # SQLAlchemy URI for the metadata database. index: str = "postgresql+psycopg://dj:dj@postgres_metadata:5432/dj" + customSchema: str = "public" + # Directory where the repository lives. This should have 2 subdirectories, "nodes" and # "databases". repository: Path = Path(".") diff --git a/datajunction-server/datajunction_server/utils.py b/datajunction-server/datajunction_server/utils.py index 4681e1552..ff5b411c0 100644 --- a/datajunction-server/datajunction_server/utils.py +++ b/datajunction-server/datajunction_server/utils.py @@ -14,7 +14,7 @@ from dotenv import load_dotenv from fastapi import Depends from rich.logging import RichHandler -from sqlalchemy import AsyncAdaptedQueuePool +from sqlalchemy import AsyncAdaptedQueuePool, text from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -70,6 +70,7 @@ def __init__(self): self.engine: AsyncEngine | None = None self.session_maker = None self.session = None + self.schema = None def init_db(self): """ @@ -94,6 +95,11 @@ def init_db(self): }, ) + if self.schema: + self.engine = self.engine.execution_options( + schema_translate_map={None: self.schema} + ) + async_session_factory = async_sessionmaker( bind=self.engine, autocommit=False, @@ -115,17 +121,20 @@ async def close(self): @lru_cache(maxsize=None) -def get_session_manager() -> DatabaseSessionManager: +def get_session_manager(request: Optional[Request] = None) -> DatabaseSessionManager: """ Get session manager """ session_manager = DatabaseSessionManager() + session_manager.schema = request.headers.get("tenant") + settings = get_settings() + settings.customSchema = request.headers.get("new_tenant") session_manager.init_db() return session_manager @lru_cache(maxsize=None) -def get_engine() -> AsyncEngine: +def get_engine(schema: str) -> AsyncEngine: """ Create the metadata engine. """ @@ -143,14 +152,16 @@ def get_engine() -> AsyncEngine: "connect_timeout": settings.db_connect_timeout, }, ) + if schema: + engine = engine.execution_options(schema_translate_map={None: schema}) return engine -async def get_session() -> AsyncIterator[AsyncSession]: +async def get_session(request: Request = None) -> AsyncIterator[AsyncSession]: """ Async database session. """ - session_manager = get_session_manager() + session_manager = get_session_manager(request) session = session_manager.session() try: yield session