|
| 1 | +# pyright: reportMissingImports=false |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import time |
| 5 | + |
| 6 | +from dotenv import load_dotenv # type: ignore |
| 7 | +from mcp.server.auth.provider import AccessToken, OAuthToken |
| 8 | +from mcp.server.auth.providers.transparent_proxy import ( |
| 9 | + ProxySettings, # type: ignore |
| 10 | + TransparentOAuthProxyProvider, |
| 11 | + ProxyTokenHandler, |
| 12 | +) |
| 13 | +from mcp.server.auth.routes import cors_middleware, create_auth_routes |
| 14 | +from mcp.server.auth.settings import ClientRegistrationOptions |
| 15 | +from pydantic import AnyHttpUrl |
| 16 | +from starlette.applications import Starlette |
| 17 | +from starlette.requests import Request # type: ignore |
| 18 | +from starlette.responses import JSONResponse, Response |
| 19 | +from starlette.routing import Route |
| 20 | +from uvicorn import Config, Server |
| 21 | + |
| 22 | +# Load environment variables from .env if present |
| 23 | +load_dotenv() |
| 24 | + |
| 25 | +# Configure logging after .env so LOG_LEVEL can come from environment |
| 26 | +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() |
| 27 | + |
| 28 | +logging.basicConfig( |
| 29 | + level=LOG_LEVEL, |
| 30 | + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
| 31 | + datefmt="%Y-%m-%d %H:%M:%S", |
| 32 | +) |
| 33 | + |
| 34 | +# Dedicated logger for this server module |
| 35 | +logger = logging.getLogger("proxy_oauth.auth_server") |
| 36 | + |
| 37 | +# Suppress noisy INFO messages from the FastMCP low-level server unless we are |
| 38 | +# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type |
| 39 | +# ListToolsRequest") are helpful for debugging but clutter normal output. |
| 40 | + |
| 41 | +_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server") |
| 42 | +if LOG_LEVEL == "DEBUG": |
| 43 | + # In full debug mode, allow the library to emit its detailed logs |
| 44 | + _mcp_lowlevel_logger.setLevel(logging.DEBUG) |
| 45 | +else: |
| 46 | + # Otherwise, only warnings and above |
| 47 | + _mcp_lowlevel_logger.setLevel(logging.WARNING) |
| 48 | + |
| 49 | +# ---------------------------------------------------------------------------- |
| 50 | +# Environment configuration |
| 51 | +# ---------------------------------------------------------------------------- |
| 52 | +# Load and validate settings from the environment (uses .env automatically) |
| 53 | +settings = ProxySettings.load() |
| 54 | + |
| 55 | +# Upstream endpoints (fully-qualified URLs) |
| 56 | +UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize) |
| 57 | +UPSTREAM_TOKEN: str = str(settings.upstream_token) |
| 58 | +UPSTREAM_JWKS_URI = settings.jwks_uri |
| 59 | +# Derive base URL from the authorize endpoint for convenience / tests |
| 60 | +UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0] |
| 61 | + |
| 62 | +# Client credentials & defaults |
| 63 | +CLIENT_ID: str = settings.client_id or "demo-client-id" |
| 64 | +CLIENT_SECRET = settings.client_secret |
| 65 | +DEFAULT_SCOPE: str = settings.default_scope |
| 66 | + |
| 67 | +# Optional audience passthrough (not part of ProxySettings yet) |
| 68 | +AUDIENCE = os.getenv("PROXY_AUDIENCE") |
| 69 | + |
| 70 | +# Metadata URL (only used if we need to fetch from upstream) |
| 71 | +UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server" |
| 72 | + |
| 73 | +# --------------------------------------------------------------------------- |
| 74 | +# Logging helpers |
| 75 | +# --------------------------------------------------------------------------- |
| 76 | + |
| 77 | + |
| 78 | +def _mask_secret(secret: str | None) -> str | None: # noqa: D401 |
| 79 | + """Return a masked version of the given secret. |
| 80 | +
|
| 81 | + The first and last four characters are preserved (if available) and the |
| 82 | + middle section is replaced by asterisks. If the secret is shorter than |
| 83 | + eight characters, the entire value is replaced by ``*``. |
| 84 | + """ |
| 85 | + |
| 86 | + if not secret: |
| 87 | + return None |
| 88 | + |
| 89 | + if len(secret) <= 8: |
| 90 | + return "*" * len(secret) |
| 91 | + |
| 92 | + return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}" |
| 93 | + |
| 94 | + |
| 95 | +# Consolidated configuration (with sensitive data redacted) |
| 96 | +_masked_settings = settings.model_dump(exclude_none=True).copy() |
| 97 | + |
| 98 | +if "client_secret" in _masked_settings: |
| 99 | + _masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"]) |
| 100 | + |
| 101 | +# Log configuration at *debug* level only so it can be enabled when needed |
| 102 | +logger.debug("[Auth Proxy Config] %s", _masked_settings) |
| 103 | + |
| 104 | +# Server host/port |
| 105 | +AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000")) |
| 106 | +AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost") |
| 107 | +AUTH_SERVER_URL = os.getenv( |
| 108 | + "AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}" |
| 109 | +) |
| 110 | + |
| 111 | +# ---------------------------------------------------------------------------- |
| 112 | +# Auth Server |
| 113 | +# ---------------------------------------------------------------------------- |
| 114 | + |
| 115 | +# Create auth provider |
| 116 | +oauth_provider = TransparentOAuthProxyProvider(settings=settings) |
| 117 | + |
| 118 | +# Enable client registration |
| 119 | +client_registration_options = ClientRegistrationOptions( |
| 120 | + enabled=True, |
| 121 | + valid_scopes=["openid"], |
| 122 | + default_scopes=["openid"], |
| 123 | +) |
| 124 | + |
| 125 | +# Create auth routes |
| 126 | +routes = create_auth_routes( |
| 127 | + provider=oauth_provider, |
| 128 | + issuer_url=AnyHttpUrl(AUTH_SERVER_URL), |
| 129 | + service_documentation_url=None, |
| 130 | + client_registration_options=client_registration_options, |
| 131 | + revocation_options=None, |
| 132 | +) |
| 133 | + |
| 134 | +# Add token endpoint handler |
| 135 | +# We need to replace any existing token endpoint route |
| 136 | +routes = [r for r in routes if not (hasattr(r, "path") and r.path == "/token")] |
| 137 | + |
| 138 | +# Create token handler and add it to routes |
| 139 | +proxy_token_handler = ProxyTokenHandler(oauth_provider) |
| 140 | +routes.append(Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"])) |
| 141 | + |
| 142 | +# Add token introspection endpoint for Resource Servers |
| 143 | +async def introspect_handler(request: Request) -> Response: |
| 144 | + """ |
| 145 | + Token introspection endpoint for Resource Servers. |
| 146 | +
|
| 147 | + Resource Servers call this endpoint to validate tokens without |
| 148 | + needing direct access to token storage. |
| 149 | + """ |
| 150 | + form = await request.form() |
| 151 | + token = form.get("token") |
| 152 | + if not token or not isinstance(token, str): |
| 153 | + return JSONResponse({"active": False}, status_code=400) |
| 154 | + |
| 155 | + # For the transparent proxy, we don't actually validate tokens |
| 156 | + # Just create a dummy AccessToken like the provider does |
| 157 | + access_token = AccessToken( |
| 158 | + token=token, client_id=str(CLIENT_ID), scopes=[DEFAULT_SCOPE], expires_at=None |
| 159 | + ) |
| 160 | + |
| 161 | + return JSONResponse( |
| 162 | + { |
| 163 | + "active": True, |
| 164 | + "client_id": access_token.client_id, |
| 165 | + "scope": " ".join(access_token.scopes), |
| 166 | + "exp": access_token.expires_at, |
| 167 | + "iat": int(time.time()), |
| 168 | + "token_type": "Bearer", |
| 169 | + "aud": access_token.resource, # RFC 8707 audience claim |
| 170 | + } |
| 171 | + ) |
| 172 | + |
| 173 | + |
| 174 | +routes.append( |
| 175 | + Route( |
| 176 | + "/introspect", |
| 177 | + endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]), |
| 178 | + methods=["POST", "OPTIONS"], |
| 179 | + ) |
| 180 | +) |
| 181 | + |
| 182 | +# Create Starlette app with routes |
| 183 | +auth_app = Starlette(routes=routes) |
| 184 | + |
| 185 | + |
| 186 | +async def run_server(): |
| 187 | + """Run the Authorization Server.""" |
| 188 | + config = Config( |
| 189 | + auth_app, |
| 190 | + host=AUTH_SERVER_HOST, |
| 191 | + port=AUTH_SERVER_PORT, |
| 192 | + log_level="info", |
| 193 | + ) |
| 194 | + server = Server(config) |
| 195 | + |
| 196 | + logger.info(f"🚀 MCP Authorization Server running on {AUTH_SERVER_URL}") |
| 197 | + |
| 198 | + await server.serve() |
| 199 | + |
| 200 | + |
| 201 | +if __name__ == "__main__": |
| 202 | + import asyncio |
| 203 | + |
| 204 | + asyncio.run(run_server()) |
0 commit comments