diff --git a/examples/servers/proxy-auth/README.md b/examples/servers/proxy-auth/README.md new file mode 100644 index 000000000..72734cabe --- /dev/null +++ b/examples/servers/proxy-auth/README.md @@ -0,0 +1,101 @@ +# OAuth Proxy Server + +This is a minimal OAuth proxy server example for the MCP Python SDK that demonstrates how to create a transparent OAuth proxy for existing OAuth providers. + +## Installation + +```bash +# Navigate to the proxy-auth directory +cd examples/servers/proxy-auth + +# Install the package in development mode +uv add -e . +``` + +## Configuration + +The servers can be configured using either: + +1. **Command-line arguments** (take precedence when provided) +2. **Environment variables** (loaded from `.env` file when present) + +Example `.env` file: + +```env +# Auth Server Configuration +AUTH_SERVER_HOST=localhost +AUTH_SERVER_PORT=9000 +AUTH_SERVER_URL=http://localhost:9000 + +# Resource Server Configuration +RESOURCE_SERVER_HOST=localhost +RESOURCE_SERVER_PORT=8001 +RESOURCE_SERVER_URL=http://localhost:8001 + +# Combo Server Configuration +COMBO_SERVER_HOST=localhost +COMBO_SERVER_PORT=8000 + +# OAuth Provider Configuration +UPSTREAM_AUTHORIZE=https://github.com/login/oauth/authorize +UPSTREAM_TOKEN=https://github.com/login/oauth/access_token +CLIENT_ID=your-client-id +CLIENT_SECRET=your-client-secret +DEFAULT_SCOPE=openid +``` + +## Running the Servers + +The example consists of three server components that can be run using the project scripts defined in pyproject.toml: + +### Step 1: Start Authorization Server + +```bash +# Start Authorization Server on port 9000 +uv run mcp-proxy-auth-as --port=9000 + +# Or rely on environment variables from .env file +uv run mcp-proxy-auth-as +``` + +**What it provides:** + +- OAuth 2.0 flows (authorization, token exchange) +- Token introspection endpoint for Resource Servers (`/introspect`) +- Client registration endpoint (`/register`) + +### Step 2: Start Resource Server (MCP Server) + +```bash +# In another terminal, start Resource Server on port 8001 +uv run mcp-proxy-auth-rs --port=8001 --auth-server=http://localhost:9000 --transport=streamable-http + +# Or rely on environment variables from .env file +uv run mcp-proxy-auth-rs +``` + +### Step 3: Alternatively, Run Combined Server + +For simpler testing, you can run a combined proxy server that handles both authentication and resource access: + +```bash +# Run the combined proxy server on port 8000 +uv run mcp-proxy-auth-combo --port=8000 --transport=streamable-http + +# Or rely on environment variables from .env file +uv run mcp-proxy-auth-combo +``` + +## How It Works + +The proxy OAuth server acts as a transparent proxy between: + +1. Client applications requesting OAuth tokens +2. Upstream OAuth providers (like GitHub, Google, etc.) + +This allows MCP servers to leverage existing OAuth providers without implementing their own authentication systems. + +The server code is organized in the `proxy_auth` package for better modularity. + +```text +``` diff --git a/examples/servers/proxy-auth/proxy_auth/__init__.py b/examples/servers/proxy-auth/proxy_auth/__init__.py new file mode 100644 index 000000000..a9caffbf1 --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/__init__.py @@ -0,0 +1,25 @@ +"""OAuth Proxy Server for MCP.""" + +__version__ = "0.1.0" + +# Import key components for easier access +from .auth_server import auth_server as auth_server +from .auth_server import main as auth_server_main +from .combo_server import combo_server as combo_server +from .combo_server import main as combo_server_main +from .resource_server import main as resource_server_main +from .resource_server import resource_server as resource_server +from .token_verifier import IntrospectionTokenVerifier + +__all__ = [ + "auth_server", + "resource_server", + "combo_server", + "IntrospectionTokenVerifier", + "auth_server_main", + "resource_server_main", + "combo_server_main", +] + +# Aliases for the script entry points +main = combo_server_main diff --git a/examples/servers/proxy-auth/proxy_auth/__main__.py b/examples/servers/proxy-auth/proxy_auth/__main__.py new file mode 100644 index 000000000..71ce1015e --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/__main__.py @@ -0,0 +1,7 @@ +"""Main entry point for Combo Proxy OAuth Resource+Auth MCP server.""" + +import sys + +from .combo_server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/proxy-auth/proxy_auth/auth_server.py b/examples/servers/proxy-auth/proxy_auth/auth_server.py new file mode 100644 index 000000000..58c35bd8d --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/auth_server.py @@ -0,0 +1,158 @@ +# pyright: reportMissingImports=false +import argparse +import logging +import os + +from dotenv import load_dotenv # type: ignore +from mcp.server.auth.providers.transparent_proxy import ( + ProxySettings, # type: ignore + TransparentOAuthProxyProvider, +) +from mcp.server.auth.settings import AuthSettings +from mcp.server.fastmcp.server import FastMCP +from pydantic import AnyHttpUrl + +# Load environment variables from .env if present +load_dotenv() + +# Configure logging after .env so LOG_LEVEL can come from environment +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +# Dedicated logger for this server module +logger = logging.getLogger("proxy_auth.auth_server") + +# Suppress noisy INFO messages from the FastMCP low-level server unless we are +# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type +# ListToolsRequest") are helpful for debugging but clutter normal output. + +_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server") +if LOG_LEVEL == "DEBUG": + # In full debug mode, allow the library to emit its detailed logs + _mcp_lowlevel_logger.setLevel(logging.DEBUG) +else: + # Otherwise, only warnings and above + _mcp_lowlevel_logger.setLevel(logging.WARNING) + +# ---------------------------------------------------------------------------- +# Environment configuration +# ---------------------------------------------------------------------------- +# Load and validate settings from the environment (uses .env automatically) +settings = ProxySettings.load() + +# Upstream endpoints (fully-qualified URLs) +UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize) +UPSTREAM_TOKEN: str = str(settings.upstream_token) +UPSTREAM_JWKS_URI = settings.jwks_uri +# Derive base URL from the authorize endpoint for convenience / tests +UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0] + +# Client credentials & defaults +CLIENT_ID: str = settings.client_id or "demo-client-id" +CLIENT_SECRET = settings.client_secret +DEFAULT_SCOPE: str = settings.default_scope + +# Metadata URL (only used if we need to fetch from upstream) +UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server" + +## Load and validate settings from the environment (uses .env automatically) +settings = ProxySettings.load() + +# Server host/port +RESOURCE_SERVER_PORT = int(os.getenv("RESOURCE_SERVER_PORT", "8000")) +RESOURCE_SERVER_HOST = os.getenv("RESOURCE_SERVER_HOST", "localhost") +RESOURCE_SERVER_URL = os.getenv( + "RESOURCE_SERVER_URL", f"http://{RESOURCE_SERVER_HOST}:{RESOURCE_SERVER_PORT}" +) + +# Auth server configuration +AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000")) +AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost") +AUTH_SERVER_URL = os.getenv( + "AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}" +) + +auth_settings = AuthSettings( + issuer_url=AnyHttpUrl(AUTH_SERVER_URL), + resource_server_url=AnyHttpUrl(RESOURCE_SERVER_URL), + required_scopes=["openid"], +) + +# Create the OAuth provider with our settings +oauth_provider = TransparentOAuthProxyProvider( + settings=settings, auth_settings=auth_settings +) + + +# ---------------------------------------------------------------------------- +# Auth Server using FastMCP +# ---------------------------------------------------------------------------- +def create_auth_server( + host: str = AUTH_SERVER_HOST, + port: int = AUTH_SERVER_PORT, + auth_settings: AuthSettings = auth_settings, + oauth_provider: TransparentOAuthProxyProvider = oauth_provider, +): + """Create a auth server instance with the given configuration.""" + + # Create FastMCP resource server instance + auth_server = FastMCP( + name="Auth Server", + host=host, + port=port, + auth_server_provider=oauth_provider, + auth=auth_settings, + ) + + return auth_server + + +# Create a default server instance +auth_server = create_auth_server() + + +def main(): + """Command-line entry point for the Authorization Server.""" + parser = argparse.ArgumentParser(description="MCP OAuth Proxy Authorization Server") + parser.add_argument( + "--host", + default=None, + help="Host to bind to (overrides AUTH_SERVER_HOST env var)", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Port to bind to (overrides AUTH_SERVER_PORT env var)", + ) + parser.add_argument( + "--transport", + default="streamable-http", + help="Transport type (streamable-http or websocket)", + ) + + args = parser.parse_args() + + # Use command-line arguments only if provided, otherwise use environment variables + host = args.host or AUTH_SERVER_HOST + port = args.port or AUTH_SERVER_PORT + + # Log the configuration being used + logger.info(f"Starting Authorization Server with host={host}, port={port}") + + # Create a server with the specified configuration + auth_server = create_auth_server( + host=host, port=port, auth_settings=auth_settings, oauth_provider=oauth_provider + ) + + logger.info(f"🚀 MCP OAuth Authorization Server running on http://{host}:{port}") + auth_server.run(transport=args.transport) + + +if __name__ == "__main__": + main() diff --git a/examples/servers/proxy-auth/proxy_auth/combo_server.py b/examples/servers/proxy-auth/proxy_auth/combo_server.py new file mode 100644 index 000000000..0ca45e80b --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/combo_server.py @@ -0,0 +1,295 @@ +# pyright: reportMissingImports=false +import argparse +import base64 +import json +import logging +import os +import time +from typing import Any + +from dotenv import load_dotenv # type: ignore +from mcp.server.auth.providers.transparent_proxy import ( + ProxySettings, # type: ignore + TransparentOAuthProxyProvider, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.server import Context, FastMCP +from pydantic import AnyHttpUrl +from starlette.requests import Request # type: ignore + +# Load environment variables from .env if present +load_dotenv() + +# Configure logging after .env so LOG_LEVEL can come from environment +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +# Dedicated logger for this server module +logger = logging.getLogger("proxy_auth.combo_server") + +# Suppress noisy INFO messages from the FastMCP low-level server unless we are +# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type +# ListToolsRequest") are helpful for debugging but clutter normal output. + +_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server") +if LOG_LEVEL == "DEBUG": + # In full debug mode, allow the library to emit its detailed logs + _mcp_lowlevel_logger.setLevel(logging.DEBUG) +else: + # Otherwise, only warnings and above + _mcp_lowlevel_logger.setLevel(logging.WARNING) + +# ---------------------------------------------------------------------------- +# Environment configuration +# ---------------------------------------------------------------------------- +# Load and validate settings from the environment (uses .env automatically) +settings = ProxySettings.load() + +# Upstream endpoints (fully-qualified URLs) +UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize) +UPSTREAM_TOKEN: str = str(settings.upstream_token) +UPSTREAM_JWKS_URI = settings.jwks_uri +# Derive base URL from the authorize endpoint for convenience / tests +UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0] + +# Client credentials & defaults +CLIENT_ID: str = settings.client_id or "demo-client-id" +CLIENT_SECRET = settings.client_secret +DEFAULT_SCOPE: str = settings.default_scope + +# Metadata URL (only used if we need to fetch from upstream) +UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server" + +# --------------------------------------------------------------------------- +# Logging helpers +# --------------------------------------------------------------------------- + + +def _mask_secret(secret: str | None) -> str | None: # noqa: D401 + """Return a masked version of the given secret. + + The first and last four characters are preserved (if available) and the + middle section is replaced by asterisks. If the secret is shorter than + eight characters, the entire value is replaced by ``*``. + """ + + if not secret: + return None + + if len(secret) <= 8: + return "*" * len(secret) + + return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}" + + +# Consolidated configuration (with sensitive data redacted) +_masked_settings = settings.model_dump(exclude_none=True).copy() + +if "client_secret" in _masked_settings: + _masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"]) + +# Log configuration at *debug* level only so it can be enabled when needed +logger.debug("[Proxy Config] %s", _masked_settings) + +# Server host/port +COMBO_SERVER_PORT = int(os.getenv("COMBO_SERVER_PORT", os.getenv("PROXY_PORT", "8000"))) +COMBO_SERVER_HOST = os.getenv("COMBO_SERVER_HOST", os.getenv("PROXY_HOST", "localhost")) +# Infer PROXY_ISSUER_URL from COMBO_SERVER_HOST and COMBO_SERVER_PORT +# if not explicitly set +PROXY_ISSUER_URL = ( + os.getenv("PROXY_ISSUER_URL") or f"http://{COMBO_SERVER_HOST}:{COMBO_SERVER_PORT}" +) + +# ---------------------------------------------------------------------------- +# FastMCP server (now created via library helper) +# ---------------------------------------------------------------------------- +auth_settings = AuthSettings( + issuer_url=AnyHttpUrl(PROXY_ISSUER_URL), # type: ignore[arg-type] + resource_server_url=AnyHttpUrl(PROXY_ISSUER_URL), # type: ignore[arg-type] + required_scopes=["openid"], + client_registration_options=ClientRegistrationOptions(enabled=True), +) + + +def create_combo_server(host: str = COMBO_SERVER_HOST, port: int = COMBO_SERVER_PORT): + """Create a combined proxy server instance with the given configuration.""" + + # Create the OAuth provider with our settings + oauth_provider = TransparentOAuthProxyProvider( + settings=settings, auth_settings=auth_settings + ) + + # Create FastMCP instance with the provider + server = FastMCP( + name="Transparent OAuth Proxy", + host=host, + port=port, + auth_server_provider=oauth_provider, + auth=auth_settings, + ) + + # Add demo tools + @server.tool() + def echo(message: str) -> str: + return f"Echo: {message}" + + @server.tool() + async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]: + """ + Get information about the authenticated user. + + This tool demonstrates accessing user information from the OAuth access token. + The user must be authenticated via OAuth to access this tool. + + Returns: + Dictionary containing user information from the access token + """ + from mcp.server.auth.middleware.auth_context import get_access_token + + # Get the access token from the authentication context + access_token = get_access_token() + + if not access_token: + return { + "error": "No access token found - user not authenticated", + "authenticated": False, + } + + # Attempt to decode the access token as JWT to extract useful user claims. + # Many OAuth providers issue JWT access tokens (or ID tokens) that contain + # the user's subject (sub) and preferred username. We parse the token + # *without* signature verification – we only need the public claims for + # display purposes. If the token is opaque or the decode fails, we simply + # skip this step. + + def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401 + """Best-effort JWT decode without verification. + + Returns the payload dictionary if the token *looks* like a JWT and can + be base64-decoded. If anything fails we return None. + """ + + try: + parts = token_str.split(".") + if len(parts) != 3: + return None # Not a JWT + + # JWT parts are URL-safe base64 without padding + def _b64decode(segment: str) -> bytes: + padding = "=" * (-len(segment) % 4) + return base64.urlsafe_b64decode(segment + padding) + + payload_bytes = _b64decode(parts[1]) + return json.loads(payload_bytes) + except Exception: # noqa: BLE001 + return None + + jwt_claims = _try_decode_jwt(access_token.token) + + # Build response with token information plus any extracted claims + response: dict[str, Any] = { + "authenticated": True, + "client_id": access_token.client_id, + "scopes": access_token.scopes, + "token_type": "Bearer", + "expires_at": access_token.expires_at, + "resource": access_token.resource, + } + + if jwt_claims: + # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if + # absent. + uid = jwt_claims.get("userid") or jwt_claims.get("sub") + if uid is not None: + response["userid"] = uid # camelCase variant used in FastMCP reference + response["user_id"] = uid # snake_case variant + response["username"] = ( + jwt_claims.get("preferred_username") + or jwt_claims.get("nickname") + or jwt_claims.get("name") + ) + response["issuer"] = jwt_claims.get("iss") + response["audience"] = jwt_claims.get("aud") + response["issued_at"] = jwt_claims.get("iat") + + # Calculate expiration helpers + if access_token.expires_at: + response["expires_at_iso"] = time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(access_token.expires_at) + ) + response["expires_in_seconds"] = max( + 0, access_token.expires_at - int(time.time()) + ) + + return response + + @server.tool() + async def test_endpoint( + message: str = "Hello from proxy server!", + ) -> dict[str, Any]: + """ + Test endpoint for debugging OAuth proxy functionality. + + Args: + message: Optional message to echo back + + Returns: + Test response with server information + """ + return { + "message": message, + "server": "Transparent OAuth Proxy Server", + "status": "active", + "oauth_configured": True, + } + + return server + + +# Create a default server instance +combo_server = create_combo_server() + + +def main(): + """Command-line entry point for the Combo Server.""" + parser = argparse.ArgumentParser(description="MCP OAuth Proxy Combo Server") + parser.add_argument( + "--host", + default=None, + help="Host to bind to (overrides COMBO_SERVER_HOST env var)", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Port to bind to (overrides COMBO_SERVER_PORT env var)", + ) + parser.add_argument( + "--transport", + default="streamable-http", + help="Transport type (streamable-http or websocket)", + ) + + args = parser.parse_args() + + # Use command-line arguments only if provided, otherwise use environment variables + host = args.host or COMBO_SERVER_HOST + port = args.port or COMBO_SERVER_PORT + + # Log the configuration being used + logger.info(f"Starting Combo Server with host={host}, port={port}") + + # Create a server with the specified configuration + combo_server = create_combo_server(host=host, port=port) + + logger.info(f"🚀 MCP OAuth Proxy Combo Server running on http://{host}:{port}") + combo_server.run(transport=args.transport) + + +if __name__ == "__main__": + main() diff --git a/examples/servers/proxy-auth/proxy_auth/resource_server.py b/examples/servers/proxy-auth/proxy_auth/resource_server.py new file mode 100644 index 000000000..c8e2e7fc7 --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/resource_server.py @@ -0,0 +1,327 @@ +# pyright: reportMissingImports=false +import argparse +import base64 +import json +import logging +import os +import time +from typing import Any + +from dotenv import load_dotenv # type: ignore +from mcp.server.auth.providers.transparent_proxy import ProxySettings # type: ignore +from mcp.server.auth.settings import AuthSettings +from mcp.server.fastmcp.server import Context, FastMCP +from pydantic import AnyHttpUrl +from starlette.requests import Request # type: ignore + +from .token_verifier import IntrospectionTokenVerifier + +# Load environment variables from .env if present +load_dotenv() + +# Configure logging after .env so LOG_LEVEL can come from environment +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +# Dedicated logger for this server module +logger = logging.getLogger("proxy_oauth.server") + +# Suppress noisy INFO messages from the FastMCP low-level server unless we are +# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type +# ListToolsRequest") are helpful for debugging but clutter normal output. + +_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server") +if LOG_LEVEL == "DEBUG": + # In full debug mode, allow the library to emit its detailed logs + _mcp_lowlevel_logger.setLevel(logging.DEBUG) +else: + # Otherwise, only warnings and above + _mcp_lowlevel_logger.setLevel(logging.WARNING) + +# ---------------------------------------------------------------------------- +# Environment configuration +# ---------------------------------------------------------------------------- +# Load and validate settings from the environment (uses .env automatically) +settings = ProxySettings.load() + +# Upstream endpoints (fully-qualified URLs) +UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize) +UPSTREAM_TOKEN: str = str(settings.upstream_token) +UPSTREAM_JWKS_URI = settings.jwks_uri +# Derive base URL from the authorize endpoint for convenience / tests +UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0] + +# Client credentials & defaults +CLIENT_ID: str = settings.client_id or "demo-client-id" +CLIENT_SECRET = settings.client_secret +DEFAULT_SCOPE: str = settings.default_scope + +# Optional audience passthrough (not part of ProxySettings yet) +AUDIENCE = os.getenv("PROXY_AUDIENCE") + +# Metadata URL (only used if we need to fetch from upstream) +UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server" + +# Server host/port +RESOURCE_SERVER_PORT = int(os.getenv("RESOURCE_SERVER_PORT", "8000")) +RESOURCE_SERVER_HOST = os.getenv("RESOURCE_SERVER_HOST", "localhost") +RESOURCE_SERVER_URL = os.getenv( + "RESOURCE_SERVER_URL", f"http://{RESOURCE_SERVER_HOST}:{RESOURCE_SERVER_PORT}" +) + +# Auth server configuration +AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000")) +AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost") +AUTH_SERVER_URL = os.getenv( + "AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}" +) + +# Create token verifier that uses the auth server's introspection endpoint +token_verifier_instance = IntrospectionTokenVerifier( + introspection_endpoint=f"{AUTH_SERVER_URL}/introspect", + server_url=RESOURCE_SERVER_URL, + validate_resource=False, # Don't validate resource for this example +) + +auth_settings = AuthSettings( + issuer_url=AnyHttpUrl(AUTH_SERVER_URL), + resource_server_url=AnyHttpUrl(RESOURCE_SERVER_URL), + required_scopes=["openid"], +) + +# --------------------------------------------------------------------------- +# Logging helpers +# --------------------------------------------------------------------------- + + +def _mask_secret(secret: str | None) -> str | None: # noqa: D401 + """Return a masked version of the given secret. + + The first and last four characters are preserved (if available) and the + middle section is replaced by asterisks. If the secret is shorter than + eight characters, the entire value is replaced by ``*``. + """ + + if not secret: + return None + + if len(secret) <= 8: + return "*" * len(secret) + + return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}" + + +# Consolidated configuration (with sensitive data redacted) +_masked_settings = settings.model_dump(exclude_none=True).copy() + +if "client_secret" in _masked_settings: + _masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"]) + +# Log configuration at *debug* level only so it can be enabled when needed +logger.debug("[Proxy Config] %s", _masked_settings) + +# ---------------------------------------------------------------------------- +# FastMCP server (now created via library helper) +# ---------------------------------------------------------------------------- + + +def create_resource_server( + host: str = RESOURCE_SERVER_HOST, + port: int = RESOURCE_SERVER_PORT, + auth_settings: AuthSettings = auth_settings, + token_verifier_instance: IntrospectionTokenVerifier = token_verifier_instance, +): + """Create a auth server instance with the given configuration.""" + + # Create FastMCP resource server instance + resource_server = FastMCP( + name="MCP Resource Server", + host=RESOURCE_SERVER_HOST, + port=RESOURCE_SERVER_PORT, + auth=auth_settings, + token_verifier=token_verifier_instance, + ) + + # --------------------------------------------------------------------------- + # Demo tools + # --------------------------------------------------------------------------- + + @resource_server.tool() + def echo(message: str) -> str: + return f"Echo: {message}" + + @resource_server.tool() + async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]: + """ + Get information about the authenticated user. + + This tool demonstrates accessing user information from the OAuth access token. + The user must be authenticated via OAuth to access this tool. + + Returns: + Dictionary containing user information from the access token + """ + from mcp.server.auth.middleware.auth_context import get_access_token + + # Get the access token from the authentication context + access_token = get_access_token() + + if not access_token: + return { + "error": "No access token found - user not authenticated", + "authenticated": False, + } + + # Attempt to decode the access token as JWT to extract useful user claims. + # Many OAuth providers issue JWT access tokens (or ID tokens) that contain + # the user's subject (sub) and preferred username. We parse the token + # *without* signature verification – we only need the public claims for + # display purposes. If the token is opaque or the decode fails, we simply + # skip this step. + + def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401 + """Best-effort JWT decode without verification. + + Returns the payload dictionary if the token *looks* like a JWT and can + be base64-decoded. If anything fails we return None. + """ + + try: + parts = token_str.split(".") + if len(parts) != 3: + return None # Not a JWT + + # JWT parts are URL-safe base64 without padding + def _b64decode(segment: str) -> bytes: + padding = "=" * (-len(segment) % 4) + return base64.urlsafe_b64decode(segment + padding) + + payload_bytes = _b64decode(parts[1]) + return json.loads(payload_bytes) + except Exception: # noqa: BLE001 + return None + + jwt_claims = _try_decode_jwt(access_token.token) + + # Build response with token information plus any extracted claims + response: dict[str, Any] = { + "authenticated": True, + "client_id": access_token.client_id, + "scopes": access_token.scopes, + "token_type": "Bearer", + "expires_at": access_token.expires_at, + "resource": access_token.resource, + } + + if jwt_claims: + # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if + # absent. + uid = jwt_claims.get("userid") or jwt_claims.get("sub") + if uid is not None: + response["userid"] = uid # camelCase variant used in FastMCP reference + response["user_id"] = uid # snake_case variant + response["username"] = ( + jwt_claims.get("preferred_username") + or jwt_claims.get("nickname") + or jwt_claims.get("name") + ) + response["issuer"] = jwt_claims.get("iss") + response["audience"] = jwt_claims.get("aud") + response["issued_at"] = jwt_claims.get("iat") + + # Calculate expiration helpers + if access_token.expires_at: + response["expires_at_iso"] = time.strftime( + "%Y-%m-%dT%H:%M:%S", time.localtime(access_token.expires_at) + ) + response["expires_in_seconds"] = max( + 0, access_token.expires_at - int(time.time()) + ) + + return response + + @resource_server.tool() + async def test_endpoint( + message: str = "Hello from proxy server!", + ) -> dict[str, Any]: + """ + Test endpoint for debugging OAuth proxy functionality. + + Args: + message: Optional message to echo back + + Returns: + Test response with server information + """ + return { + "message": message, + "server": "Transparent OAuth Proxy Resource Server", + "status": "active", + "oauth_configured": True, + } + + return resource_server + + +# Create a default server instance +resource_server = create_resource_server() + + +def main(): + """Command-line entry point for the Resource Server.""" + parser = argparse.ArgumentParser(description="MCP OAuth Proxy Resource Server") + parser.add_argument( + "--host", + default=None, + help="Host to bind to (overrides RESOURCE_SERVER_HOST env var)", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Port to bind to (overrides RESOURCE_SERVER_PORT env var)", + ) + parser.add_argument( + "--auth-server", + default=None, + help="URL of the authorization server (overrides AUTH_SERVER_URL env var)", + ) + parser.add_argument( + "--transport", + default="streamable-http", + help="Transport type (streamable-http or websocket)", + ) + + args = parser.parse_args() + + # Use command-line arguments only if provided, otherwise use environment variables + host = args.host or RESOURCE_SERVER_HOST + port = args.port or RESOURCE_SERVER_PORT + auth_server = args.auth_server or AUTH_SERVER_URL + + # Log the configuration being used + logger.info( + f"Starting Resource Server with host={host}, port={port}, " + f"auth_server={auth_server}" + ) + logger.info("Using environment variables from .env file if present") + + # Create a server with the specified configuration + resource_server = create_resource_server( + host=host, + port=port, + auth_settings=auth_settings, + token_verifier_instance=token_verifier_instance, + ) + + logger.info(f"🚀 MCP Resource Server running on http://{host}:{port}") + resource_server.run(transport=args.transport) + + +if __name__ == "__main__": + resource_server.run(transport="streamable-http") diff --git a/examples/servers/proxy-auth/proxy_auth/token_verifier.py b/examples/servers/proxy-auth/proxy_auth/token_verifier.py new file mode 100644 index 000000000..3574e041c --- /dev/null +++ b/examples/servers/proxy-auth/proxy_auth/token_verifier.py @@ -0,0 +1,118 @@ +"""Example token verifier implementation using OAuth 2.0 Token Introspection.""" + +import logging +from typing import Any + +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url + +logger = logging.getLogger(__name__) + + +class IntrospectionTokenVerifier(TokenVerifier): + """Example token verifier that uses OAuth 2.0 Token Introspection (RFC 7662). + + This is a simple example implementation for demonstration purposes. + Production implementations should consider: + - Connection pooling and reuse + - More sophisticated error handling + - Rate limiting and retry logic + - Comprehensive configuration options + """ + + def __init__( + self, + introspection_endpoint: str, + server_url: str, + validate_resource: bool = False, + ): + self.introspection_endpoint = introspection_endpoint + self.server_url = server_url + self.validate_resource = validate_resource + self.resource_url = resource_url_from_server_url(server_url) + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify token via introspection endpoint.""" + import httpx + + # Validate URL to prevent SSRF attacks + if not self.introspection_endpoint.startswith( + ("https://", "http://localhost", "http://127.0.0.1") + ): + logger.warning( + f"Rejecting introspection endpoint with unsafe scheme: " + f"{self.introspection_endpoint}" + ) + return None + + # Configure secure HTTP client + timeout = httpx.Timeout(10.0, connect=5.0) + limits = httpx.Limits(max_connections=10, max_keepalive_connections=5) + + async with httpx.AsyncClient( + timeout=timeout, + limits=limits, + verify=True, # Enforce SSL verification + ) as client: + try: + response = await client.post( + self.introspection_endpoint, + data={"token": token}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code != 200: + logger.debug( + f"Token introspection returned status {response.status_code}" + ) + return None + + data = response.json() + if not data.get("active", False): + return None + + # RFC 8707 resource validation (only when --oauth-strict is set) + if self.validate_resource and not self._validate_resource(data): + logger.warning( + f"Token resource validation failed. Expected: " + f"{self.resource_url}" + ) + return None + + return AccessToken( + token=token, + client_id=data.get("client_id", "unknown"), + scopes=data.get("scope", "").split() if data.get("scope") else [], + expires_at=data.get("exp"), + resource=data.get("aud"), # Include resource in token + ) + except Exception as e: + logger.warning(f"Token introspection failed: {e}") + return None + + def _validate_resource(self, token_data: dict[str, Any]) -> bool: + """Validate token was issued for this resource server.""" + if not self.server_url or not self.resource_url: + return False # Fail if strict validation requested but URLs missing + + # Check 'aud' claim first (standard JWT audience) + aud = token_data.get("aud") + if isinstance(aud, list): + for audience in aud: + if self._is_valid_resource(audience): + return True + return False + elif aud: + return self._is_valid_resource(aud) + + # No resource binding - invalid per RFC 8707 + return False + + def _is_valid_resource(self, resource: str) -> bool: + """Check if resource matches this server using hierarchical matching.""" + if not self.resource_url: + return False + + return check_resource_allowed( + requested_resource=self.resource_url, configured_resource=resource + ) diff --git a/examples/servers/proxy-auth/pyproject.toml b/examples/servers/proxy-auth/pyproject.toml new file mode 100644 index 000000000..b0e448849 --- /dev/null +++ b/examples/servers/proxy-auth/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "proxy_auth" +version = "0.1.0" +description = "OAuth Proxy Server" +authors = [{ name = "Your Name" }] +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "mcp", +] + +[project.optional-dependencies] +dev = [ + "pytest>=6.0", +] + +[project.scripts] +mcp-proxy-auth-rs = "proxy_auth.resource_server:main" +mcp-proxy-auth-as = "proxy_auth.auth_server:main" +mcp-proxy-auth-combo = "proxy_auth.combo_server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["proxy_auth"] + +[tool.pyright] +include = ["proxy_auth"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py311" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] +extras = ["dev"] + +[[tool.uv.index]] +url = "https://pypi.org/simple" \ No newline at end of file diff --git a/examples/servers/proxy-auth/tests/__init__.py b/examples/servers/proxy-auth/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/proxy-auth/tests/test_proxy_oauth_endpoints.py b/examples/servers/proxy-auth/tests/test_proxy_oauth_endpoints.py new file mode 100644 index 000000000..10c61df25 --- /dev/null +++ b/examples/servers/proxy-auth/tests/test_proxy_oauth_endpoints.py @@ -0,0 +1,248 @@ +# pyright: reportMissingImports=false +# pytest test suite for proxy_auth/combo_server.py +# These tests spin up the FastMCP Starlette application in-process and +# exercise the custom HTTP routes as well as the `user_info` tool. + +from __future__ import annotations + +import base64 +import json +import urllib.parse +from collections.abc import AsyncGenerator +from typing import Any + +import httpx # type: ignore +import pytest # type: ignore + +# Import constants at the module level +from proxy_auth.combo_server import ( + CLIENT_ID, + UPSTREAM_AUTHORIZE, + UPSTREAM_BASE, + UPSTREAM_TOKEN, +) + + +@pytest.fixture +def proxy_server(monkeypatch): + """Import the proxy OAuth demo server with safe environment + stubs.""" + + import os + + # Avoid real outbound calls by pretending the upstream endpoints were + # supplied explicitly via env vars – this makes `fetch_upstream_metadata` + # construct metadata locally instead of performing an HTTP GET. + os.environ.setdefault( + "UPSTREAM_AUTHORIZATION_ENDPOINT", "https://upstream.example.com/authorize" + ) + os.environ.setdefault( + "UPSTREAM_TOKEN_ENDPOINT", "https://upstream.example.com/token" + ) + os.environ.setdefault("UPSTREAM_JWKS_URI", "https://upstream.example.com/jwks") + os.environ.setdefault("UPSTREAM_CLIENT_ID", "client123") + os.environ.setdefault("UPSTREAM_CLIENT_SECRET", "secret123") + + # Deferred import so the env vars above are in effect. + # Stub library-level fetch_upstream_metadata to avoid network I/O. + from mcp.server.auth.proxy import routes as proxy_routes + + # Import the module and the combo_server instance + from proxy_auth import combo_server + + async def _fake_metadata() -> dict[str, Any]: # noqa: D401 + # Access module-level constants directly + return { + "issuer": UPSTREAM_BASE, + "authorization_endpoint": UPSTREAM_AUTHORIZE, + "token_endpoint": UPSTREAM_TOKEN, + "registration_endpoint": "/register", + "jwks_uri": "", + } + + monkeypatch.setattr( + proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True + ) + + # Return the combo_server instance + return combo_server + + +@pytest.fixture +def app(proxy_server): + """Return the Starlette ASGI app for tests.""" + return proxy_server.streamable_http_app() + + +@pytest.fixture +async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]: + """Async HTTP client bound to the in-memory ASGI application.""" + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver" + ) as c: + yield c + + +# --------------------------------------------------------------------------- +# HTTP endpoint tests +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_metadata_endpoint(client): + r = await client.get("/.well-known/oauth-authorization-server") + assert r.status_code == 200 + data = r.json() + assert "issuer" in data + assert data["authorization_endpoint"].endswith("/authorize") + assert data["token_endpoint"].endswith("/token") + assert data["registration_endpoint"].endswith("/register") + + +@pytest.mark.anyio +async def test_registration_endpoint(client): + payload = {"redirect_uris": ["https://client.example.com/callback"]} + r = await client.post("/register", json=payload) + assert r.status_code == 201 + body = r.json() + assert body["client_id"] == CLIENT_ID + assert body["redirect_uris"] == payload["redirect_uris"] + # client_secret may be None, but the field should exist (masked or real) + assert "client_secret" in body + + +@pytest.mark.anyio +async def test_authorize_redirect(client): + params = { + "response_type": "code", + "state": "xyz", + "redirect_uri": "https://client.example.com/callback", + "client_id": CLIENT_ID, + "code_challenge": "testchallenge", + "code_challenge_method": "S256", + } + r = await client.get("/authorize", params=params, follow_redirects=False) + assert r.status_code in {302, 307} + + location = r.headers["location"] + parsed = urllib.parse.urlparse(location) + assert parsed.scheme.startswith("http") + assert parsed.netloc == urllib.parse.urlparse(UPSTREAM_AUTHORIZE).netloc + + qs = urllib.parse.parse_qs(parsed.query) + # Proxy should inject client_id & default scope + assert qs["client_id"][0] == CLIENT_ID + assert "scope" in qs + # Original params preserved + assert qs["state"][0] == "xyz" + + +@pytest.mark.anyio +async def test_revoke_proxy(client, monkeypatch): + original_post = httpx.AsyncClient.post + + async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401 + if url.endswith("/revoke"): + return httpx.Response(200, json={"revoked": True}) + # For the test client's own request to /revoke, + # delegate to original implementation + return await original_post(self, url, data=data, timeout=timeout, **kwargs) + + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True) + + r = await client.post("/revoke", data={"token": "dummy"}) + assert r.status_code == 200 + assert r.json() == {"revoked": True} + + +@pytest.mark.anyio +async def test_token_passthrough(client, monkeypatch): + """Ensure /token is proxied unchanged and response is returned verbatim.""" + + # Capture outgoing POSTs made by ProxyTokenHandler + captured: dict[str, Any] = {} + + original_post = httpx.AsyncClient.post + + async def _mock_post(self, url, *args, **kwargs): # noqa: D401 + if str(url).startswith(UPSTREAM_TOKEN): + # Record exactly what was sent upstream + captured["url"] = str(url) + captured["data"] = kwargs.get("data") + # Return a dummy upstream response + return httpx.Response( + 200, + json={ + "access_token": "xyz", + "token_type": "bearer", + "expires_in": 3600, + }, + ) + # Delegate any other POSTs to the real implementation + return await original_post(self, url, *args, **kwargs) + + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_post, raising=True) + + # ---------------- Act ---------------- + form = { + "grant_type": "authorization_code", + "code": "dummy-code", + "client_id": CLIENT_ID, + } + r = await client.post("/token", data=form) + + # ---------------- Assert ------------- + assert r.status_code == 200 + assert r.json()["access_token"] == "xyz" + + # Verify the request payload was forwarded without modification + assert captured["data"] == form + + +# --------------------------------------------------------------------------- +# Tool invocation – user_info +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_user_info_tool(monkeypatch, proxy_server): + """Call the `user_info` tool directly with a mocked access token.""" + # Craft a dummy JWT with useful claims (header/payload/signature parts) + payload = ( + base64.urlsafe_b64encode( + json.dumps( + { + "sub": "test-user", + "preferred_username": "tester", + } + ).encode() + ) + .decode() + .rstrip("=") + ) + dummy_token = f"header.{payload}.signature" + + from mcp.server.auth.middleware import auth_context + from mcp.server.auth.provider import AccessToken # local import to avoid cycles + + def _fake_get_access_token(): # noqa: D401 + return AccessToken( + token=dummy_token, client_id="client123", scopes=["openid"], expires_at=None + ) + + monkeypatch.setattr( + auth_context, "get_access_token", _fake_get_access_token, raising=True + ) + + result = await proxy_server.call_tool("user_info", {}) + + # call_tool returns (content_blocks, raw_result) + if isinstance(result, tuple): + _, raw = result + else: + raw = result # fallback + + assert raw["authenticated"] is True + assert ("userid" in raw and raw["userid"] == "test-user") or ( + "user_id" in raw and raw["user_id"] == "test-user" + ) + assert raw["username"] == "tester" diff --git a/examples/servers/proxy-auth/uv.lock b/examples/servers/proxy-auth/uv.lock new file mode 100644 index 000000000..6b34d73f9 --- /dev/null +++ b/examples/servers/proxy-auth/uv.lock @@ -0,0 +1,83 @@ +version = 1 +revision = 1 +requires-python = ">=3.10" + +[[package]] +name = "mcp" +version = "0.1.0" +source = { workspace = true } + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" } +] + +[[package]] +name = "pyright" +version = "1.1.391" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" } +] + +[[package]] +name = "ruff" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "dotenv" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "starlette" +version = "0.36.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" } +] + +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" } +] + +[[package]] +name = "idna" +version = "3.6" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "sniffio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "packaging" +version = "24.0" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "pluggy" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } + +[[package]] +name = "nodeenv" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..a39e51eb7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +pythonpath = . +testpaths = tests +python_files = test_*.py \ No newline at end of file diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 0a371610b..76af2150b 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,3 +1,4 @@ +# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportUnknownMemberType=false import json import logging from collections.abc import AsyncGenerator diff --git a/src/mcp/server/auth/__init__.py b/src/mcp/server/auth/__init__.py index 6888ffe8d..92d996367 100644 --- a/src/mcp/server/auth/__init__.py +++ b/src/mcp/server/auth/__init__.py @@ -1,3 +1,47 @@ +# pyright: reportGeneralTypeIssues=false """ MCP OAuth server authorization components. """ + +# Convenience re-exports so users can simply:: +# +# from mcp.server.auth import build_proxy_server +# +# instead of digging into the sub-package path. + +from typing import TYPE_CHECKING + +from mcp.server.auth.proxy import ( + create_proxy_routes, + fetch_upstream_metadata, +) +from mcp.server.fastmcp.utilities.logging import configure_logging + +# For *build_proxy_server* we need a lazy import to avoid a circular reference +# during the initial package import sequence (FastMCP -> auth -> proxy -> +# FastMCP ...). PEP 562 allows us to implement module-level `__getattr__` for +# this purpose. + + +def __getattr__(name: str): # noqa: D401 + if name == "build_proxy_server": + from mcp.server.auth.proxy.server import build_proxy_server as _bps # noqa: WPS433 + + globals()["build_proxy_server"] = _bps + return _bps + raise AttributeError(name) + + +# --------------------------------------------------------------------------- +# Public API specification +# --------------------------------------------------------------------------- + +__all__: list[str] = [ + "configure_logging", + "create_proxy_routes", + "fetch_upstream_metadata", + "build_proxy_server", +] + +if TYPE_CHECKING: # pragma: no cover – make *build_proxy_server* visible to type checkers + from mcp.server.auth.proxy.server import build_proxy_server # noqa: F401 diff --git a/src/mcp/server/auth/providers/transparent_proxy.py b/src/mcp/server/auth/providers/transparent_proxy.py new file mode 100644 index 000000000..62c8724e1 --- /dev/null +++ b/src/mcp/server/auth/providers/transparent_proxy.py @@ -0,0 +1,503 @@ +# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportUnknownArgumentType=false, reportCallIssue=false, reportUnnecessaryIsInstance=false +from __future__ import annotations + +import logging +import os +import time +import uuid +from collections.abc import Mapping +from typing import Any, cast +from urllib.parse import urlencode + +import httpx # type: ignore +from pydantic import AnyHttpUrl, AnyUrl, Field +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route + +from mcp.server.auth.handlers.token import TokenHandler +from mcp.server.auth.middleware.client_auth import ClientAuthenticator +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, +) +from mcp.server.auth.proxy.routes import create_proxy_routes +from mcp.server.auth.routes import cors_middleware, create_auth_routes +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.utilities.logging import redact_sensitive_data +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +"""Transparent OAuth proxy provider for FastMCP (Anthropic SDK). + +This provider mimics the behaviour of fastapi_mcp's `setup_proxies=True` and the +`TransparentOAuthProxyProvider` from the `fastmcp` fork. It forwards all real +OAuth traffic (authorize / token / jwks) to an upstream Authorization Server +(AS) while *locally* implementing Dynamic Client Registration so that MCP +clients such as Cursor can register even when the upstream AS disables RFC 7591 +registration. + +Environment variables (all optional – if omitted fall back to sensible defaults +or raise clearly): + +UPSTREAM_AUTHORIZATION_ENDPOINT Full URL of the upstream `/authorize` endpoint +UPSTREAM_TOKEN_ENDPOINT Full URL of the upstream `/token` endpoint +UPSTREAM_JWKS_URI URL of the upstream JWKS (optional, not yet used) +UPSTREAM_CLIENT_ID Fixed client_id registered with the upstream +UPSTREAM_CLIENT_SECRET Fixed secret (omit for public client) + +PROXY_DEFAULT_SCOPE Space-separated default scope (default: "openid") + +A simple helper ``TransparentOAuthProxyProvider.from_env()`` reads these vars. +""" + +__all__ = ["TransparentOAuthProxyProvider"] + +logger = logging.getLogger("transparent_oauth_proxy") + + +class ProxyTokenHandler(TokenHandler): + """Token handler that simply proxies token requests to the upstream AS. + + We intentionally bypass redirect_uri and PKCE checks that the normal + ``TokenHandler`` performs because in *transparent proxy* mode we do not + have enough information locally. Instead of validating, we forward the + form untouched to the upstream token endpoint and stream the response + back to the caller. + """ + + def __init__(self, provider: TransparentOAuthProxyProvider): + # We provide a dummy ClientAuthenticator that will accept any client – + # we are not going to invoke the base-class logic anyway. + super().__init__(provider=provider, client_authenticator=ClientAuthenticator(provider)) + self.provider = provider # keep for easy access + self.settings = provider.get_settings() # store settings for easier access + + async def handle(self, request) -> Response: # type: ignore[override] + correlation_id = str(uuid.uuid4())[:8] + start_time = time.time() + + logger.info(f"[{correlation_id}] 🔄 ProxyTokenHandler - passthrough") + + try: + form = await request.form() + form_dict = dict(form) + + redacted_form = redact_sensitive_data(form_dict) + logger.debug(f"[{correlation_id}] ➡︎ Incoming form: {redacted_form}") + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + "User-Agent": "MCP-TransparentProxy/1.0", + } + + http = self.provider.http_client + logger.info(f"[{correlation_id}] ⮕ Forwarding to {self.settings.upstream_token}") + upstream_resp = await http.post(str(self.settings.upstream_token), data=form_dict, headers=headers) + + except httpx.HTTPError as exc: + logger.error(f"[{correlation_id}] ✗ Upstream HTTP error: {exc}") + return Response( + content='{"error":"server_error","error_description":"Upstream server error"}', + status_code=502, + headers={"Content-Type": "application/json"}, + ) + except Exception as exc: + logger.error(f"[{correlation_id}] ✗ Unexpected proxy error: {exc}") + return Response( + content='{"error":"server_error"}', + status_code=500, + headers={"Content-Type": "application/json"}, + ) + + finally: + elapsed = time.time() - start_time + logger.info(f"[{correlation_id}] ⏱ Finished in {elapsed:.2f}s") + + # Log upstream response (redacted) + try: + if upstream_resp.headers.get("content-type", "").startswith("application/json"): + body = upstream_resp.json() + logger.info( + f"[{correlation_id}] ⬅︎ Body: {redact_sensitive_data(body) if isinstance(body, dict) else body}" + ) + except Exception: + pass + + return Response( + content=upstream_resp.content, + status_code=upstream_resp.status_code, + headers=dict(upstream_resp.headers), + ) + + +class ProxyIntrospectionHandler: + """Handler for token introspection endpoint. + + Resource Servers call this endpoint to validate tokens without + needing direct access to token storage. + """ + + def __init__(self, provider: TransparentOAuthProxyProvider, client_id: str, default_scope: str): + self.provider = provider + self.client_id = client_id + self.default_scope = default_scope + + async def handle(self, request: Request) -> Response: + """ + Token introspection endpoint for Resource Servers. + """ + form = await request.form() + token = form.get("token") + if not token or not isinstance(token, str): + return JSONResponse({"active": False}, status_code=400) + + # For the transparent proxy, we don't actually validate tokens + # Just create a dummy AccessToken like the provider does + access_token = AccessToken(token=token, client_id=self.client_id, scopes=[self.default_scope], expires_at=None) + + return JSONResponse( + { + "active": True, + "client_id": access_token.client_id, + "scope": " ".join(access_token.scopes), + "exp": access_token.expires_at, + "iat": int(time.time()), + "token_type": "Bearer", + "aud": access_token.resource, # RFC 8707 audience claim + } + ) + + +class ProxyRegistrationHandler: + """Handler for client registration endpoint. + + This handler implements a simplified version of Dynamic Client Registration + that always returns the upstream client credentials. + """ + + def __init__(self, provider: TransparentOAuthProxyProvider): + self.provider = provider + # Store settings for easier access + self.settings = provider.get_settings() + + async def handle(self, request: Request) -> Response: + """ + Client registration endpoint that returns upstream credentials. + """ + correlation_id = str(uuid.uuid4())[:8] + logger.info(f"[{correlation_id}] 🔑 ProxyRegistrationHandler - registration request") + + try: + body = await request.json() + + # Log the incoming request body (redacted) + redacted_body = redact_sensitive_data(body) + logger.info(f"[{correlation_id}] ➡︎ Incoming registration request: {redacted_body}") + + # Create response with upstream credentials + client_metadata = { + "client_id": str(self.settings.client_id), + "client_secret": self.settings.client_secret, + "token_endpoint_auth_method": "none" if self.settings.client_secret is None else "client_secret_post", + **body, # Include original request fields + } + + # Log the client ID we're returning + logger.info(f"[{correlation_id}] ⬅︎ Returning client_id: {self.settings.client_id}") + + return JSONResponse(client_metadata, status_code=201) + + except Exception as exc: + logger.error(f"[{correlation_id}] ✗ Registration error: {exc}") + return JSONResponse( + {"error": "invalid_client_metadata", "error_description": str(exc)}, + status_code=400, + ) + + +class ProxySettings(BaseSettings): + """Validated environment-driven settings for the transparent OAuth proxy.""" + + model_config = SettingsConfigDict(env_file=".env", populate_by_name=True, extra="ignore") + + upstream_authorize: AnyHttpUrl = Field(..., alias="UPSTREAM_AUTHORIZATION_ENDPOINT") + upstream_token: AnyHttpUrl = Field(..., alias="UPSTREAM_TOKEN_ENDPOINT") + jwks_uri: str | None = Field(None, alias="UPSTREAM_JWKS_URI") + + client_id: str | None = Field(None, alias="UPSTREAM_CLIENT_ID") + client_secret: str | None = Field(None, alias="UPSTREAM_CLIENT_SECRET") + + # Allow overriding via env var, but default to "openid" if not provided + default_scope: str = Field("openid", alias="PROXY_DEFAULT_SCOPE") + + @classmethod + def load(cls) -> ProxySettings: + """Instantiate settings from environment variables (for backwards compatibility).""" + return cls() + + +# Backwards-compatibility alias – existing callers/tests import `_Settings` +_Settings = ProxySettings # type: ignore + + +class TransparentOAuthProxyProvider(OAuthAuthorizationServerProvider[AuthorizationCode, Any, AccessToken]): + """Minimal pass-through provider – only implements code flow, no refresh.""" + + def __init__(self, *, settings: ProxySettings, auth_settings: AuthSettings): + # Fill in client_id fallback if not provided via upstream var + if settings.client_id is None: + settings.client_id = os.getenv("PROXY_CLIENT_ID", "demo-client-id") # type: ignore[assignment] + assert settings.client_id is not None, "client_id must be provided" + self._s = settings + self._auth_settings = auth_settings + # simple in-memory auth-code store (maps code→AuthorizationCode) + self._codes: dict[str, AuthorizationCode] = {} + # always the same client info returned by /register + self._static_client = OAuthClientInformationFull( + client_id=str(self._s.client_id), + client_secret=self._s.client_secret, + redirect_uris=[cast(AnyUrl, cast(object, "http://localhost"))], + grant_types=["authorization_code"], + token_endpoint_auth_method="none" if self._s.client_secret is None else "client_secret_post", + ) + + # Single reusable HTTP client for communicating with the upstream AS + self._http: httpx.AsyncClient = httpx.AsyncClient(timeout=15) + + def get_settings(self) -> ProxySettings: + """Return the provider's settings.""" + return self._s + + # Expose http client for handlers + @property + def http_client(self) -> httpx.AsyncClient: # noqa: D401 + return self._http + + async def aclose(self) -> None: + """Close the underlying HTTP client.""" + await self._http.aclose() + + # --------------------------------------------------------------------- + # Dynamic Client Registration – always enabled + # --------------------------------------------------------------------- + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: # noqa: D401 + logger.info(f"🔍 get_client called with client_id: {client_id}") + logger.info(f"Expected client_id from settings: {self._s.client_id}") + result = self._static_client if client_id == self._s.client_id else None + logger.info(f"Client found: {result is not None}") + return result + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: # noqa: D401 + """Spoof DCR: overwrite the incoming info with fixed credentials.""" + + logger.info("🔑 register_client method called in TransparentOAuthProxyProvider") + logger.info(f"Original client_id: {client_info.client_id}") + + client_info.client_id = str(self._s.client_id) + client_info.client_secret = self._s.client_secret + # Ensure token_endpoint_auth_method reflects whether secret exists + client_info.token_endpoint_auth_method = "none" if self._s.client_secret is None else "client_secret_post" + # Replace stored static client redirect URIs with provided ones so later validation passes + self._static_client.redirect_uris = client_info.redirect_uris + + logger.info(f"Modified client_id to: {client_info.client_id}") + if self._s.client_secret: + logger.info("Set client_secret from settings") + + return None + + # ------------------------------------------------------------------ + # Authorization endpoint – redirect to upstream + # ------------------------------------------------------------------ + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # noqa: D401 + query: dict[str, str | None] = { + "response_type": "code", + "client_id": str(self._s.client_id), + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "code_challenge_method": "S256", + "scope": " ".join(params.scopes or [self._s.default_scope]), + "state": params.state, + } + return f"{self._s.upstream_authorize}?{urlencode({k: v for k, v in query.items() if v})}" + + # ------------------------------------------------------------------ + # Auth-code tracking / exchange + # ------------------------------------------------------------------ + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: # noqa: D401,E501 + # create lightweight object; we cannot verify with upstream at this stage + return AuthorizationCode( + code=authorization_code, + scopes=[self._s.default_scope], + expires_at=int(time.time() + 300), + client_id=str(self._s.client_id), + redirect_uri=cast(AnyUrl, cast(object, "http://localhost")), # type: ignore[arg-type] + redirect_uri_provided_explicitly=False, + code_challenge="", # not validated here + ) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: # noqa: D401,E501 + # Generate correlation ID for this request + correlation_id = str(uuid.uuid4())[:8] + start_time = time.time() + + logger.info(f"[{correlation_id}] Starting token exchange for client_id={client.client_id}") + + data: dict[str, str] = { + "grant_type": "authorization_code", + "client_id": str(self._s.client_id), + "code": authorization_code.code, + "redirect_uri": str(authorization_code.redirect_uri), + } + if self._s.client_secret: + data["client_secret"] = self._s.client_secret + + # Log outgoing request with full details + redacted_data = redact_sensitive_data(data) + logger.info(f"[{correlation_id}] ⮕ Preparing upstream token request") + logger.info(f"[{correlation_id}] ⮕ Target URL: {self._s.upstream_token}") + logger.info(f"[{correlation_id}] ⮕ Request data: {redacted_data}") + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + "User-Agent": "MCP-TransparentProxy/1.0", + } + logger.info(f"[{correlation_id}] ⮕ Request headers: {headers}") + + http = self.http_client + try: + logger.info(f"[{correlation_id}] ⮕ Sending POST request to upstream") + resp = await http.post(str(self._s.upstream_token), data=data, headers=headers) + + elapsed_time = time.time() - start_time + logger.info(f"[{correlation_id}] ⬅︎ Upstream response received in {elapsed_time:.2f}s") + logger.info(f"[{correlation_id}] ⬅︎ Status: {resp.status_code}") + logger.info(f"[{correlation_id}] ⬅︎ Headers: {dict(resp.headers)}") + + # Log response body (redacted) + try: + body = resp.json() + redacted_body = redact_sensitive_data(body) if isinstance(body, dict) else body + logger.info(f"[{correlation_id}] ⬅︎ Response body: {redacted_body}") + except Exception as e: + logger.warning(f"[{correlation_id}] ⬅︎ Could not parse response as JSON: {e}") + logger.info(f"[{correlation_id}] ⬅︎ Raw response: {resp.text[:500]}...") + + resp.raise_for_status() + + except httpx.HTTPError as e: + logger.error(f"[{correlation_id}] ⬅︎ HTTP error occurred: {e}") + raise + except Exception as e: + logger.error(f"[{correlation_id}] ⬅︎ Unexpected error: {e}") + raise + + body: Mapping[str, Any] = resp.json() + logger.info(f"[{correlation_id}] ✓ Token exchange completed successfully") + return OAuthToken(**body) # type: ignore[arg-type] + + # ------------------------------------------------------------------ + # Unused grant types + # ------------------------------------------------------------------ + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str): # noqa: D401 + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: str, + scopes: list[str], + ) -> OAuthToken: # noqa: D401 + raise NotImplementedError + + async def load_access_token(self, token: str) -> AccessToken | None: # noqa: D401 + # For now we cannot validate JWT; return a dummy AccessToken so BearerAuth passes. + return AccessToken( + token=token, client_id=str(self._s.client_id), scopes=[self._s.default_scope], expires_at=None + ) + + async def revoke_token(self, token: object) -> None: # noqa: D401 + return None + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @classmethod + def from_env(cls) -> TransparentOAuthProxyProvider: + """Construct provider using :class:`ProxySettings` populated from the environment.""" + return cls(settings=ProxySettings.load()) + + # FastMCP will read `client_registration_options` to decide whether to expose /register + @property + def client_registration_options(self) -> ClientRegistrationOptions: # type: ignore[override] + return ClientRegistrationOptions(enabled=True) + + # ------------------------------------------------------------------ + # Provide custom auth routes so that our proxy /token endpoint overrides the default one + # ------------------------------------------------------------------ + + def get_auth_routes(self): # type: ignore[override] + """Return full auth+proxy route list for FastMCP.""" + + routes = create_auth_routes( + provider=self, + issuer_url=self._auth_settings.issuer_url, + client_registration_options=self.client_registration_options, + revocation_options=None, + service_documentation_url=None, + ) + + # Drop default /token, /authorize, and /register handlers – we provide custom ones. + routes = [r for r in routes if not (isinstance(r, Route) and r.path in {"/token", "/authorize", "/register"})] + + # Insert proxy /token handler first for high precedence + proxy_token_handler = ProxyTokenHandler(self) + routes.insert(0, Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"])) + + # Add registration endpoint + proxy_registration_handler = ProxyRegistrationHandler(self) + routes.insert(1, Route("/register", endpoint=proxy_registration_handler.handle, methods=["POST"])) + + # Add introspection endpoint + proxy_introspection_handler = ProxyIntrospectionHandler( + provider=self, client_id=str(self._s.client_id), default_scope=self._s.default_scope + ) + routes.insert( + 1, + Route( + "/introspect", + endpoint=cors_middleware(proxy_introspection_handler.handle, ["POST", "OPTIONS"]), + methods=["POST", "OPTIONS"], + ), + ) + + # Get proxy routes but filter out any that would conflict with our custom handlers + proxy_routes = create_proxy_routes(self) + proxy_routes = [ + r for r in proxy_routes if not (isinstance(r, Route) and r.path in {"/token", "/register", "/introspect"}) + ] + + # Log the final route configuration + logger.debug("Final route configuration:") + for r in routes + proxy_routes: + if isinstance(r, Route): + logger.debug(f" {r.path} - Methods: {r.methods}") + + # Append additional proxy endpoints (metadata, authorize, revoke…) + routes.extend(proxy_routes) + + return routes diff --git a/src/mcp/server/auth/proxy/__init__.py b/src/mcp/server/auth/proxy/__init__.py new file mode 100644 index 000000000..b926a733a --- /dev/null +++ b/src/mcp/server/auth/proxy/__init__.py @@ -0,0 +1,30 @@ +"""Transparent OAuth proxy helpers (library form). + +This sub-package turns the demo-level transparent OAuth proxy into a reusable +component: + +* create_proxy_routes(provider) – returns the Starlette routes that expose the + proxy endpoints (/authorize, /revoke …). +* build_proxy_server() – convenience helper that wires everything into a + FastMCP instance. + +The functions are re-exported here so users can simply:: + + from mcp.server.auth.proxy import build_proxy_server + +""" + +from __future__ import annotations + +# Public re-exports +from .routes import create_proxy_routes, fetch_upstream_metadata + +__all__: list[str] = [ + "create_proxy_routes", + "fetch_upstream_metadata", +] + +# build_proxy_server intentionally *not* imported here to avoid circular +# imports with TransparentOAuthProxyProvider. Import from +# `mcp.server.auth.proxy.server` when needed: +# from mcp.server.auth.proxy.server import build_proxy_server diff --git a/src/mcp/server/auth/proxy/routes.py b/src/mcp/server/auth/proxy/routes.py new file mode 100644 index 000000000..9d4be23da --- /dev/null +++ b/src/mcp/server/auth/proxy/routes.py @@ -0,0 +1,175 @@ +# pyright: reportGeneralTypeIssues=false +"""Starlette routes that implement the transparent OAuth proxy endpoints.""" + +from __future__ import annotations + +import logging +import urllib.parse +from typing import Any + +import httpx # type: ignore +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.routing import Route + +from mcp.server.fastmcp.utilities.logging import configure_logging + +__all__ = ["fetch_upstream_metadata", "create_proxy_routes"] + +logger = logging.getLogger("transparent_oauth_proxy.routes") + + +# --------------------------------------------------------------------------- +# Helper – fetch (or synthesise) upstream AS metadata +# --------------------------------------------------------------------------- + + +async def fetch_upstream_metadata( # noqa: D401 + upstream_base: str, + upstream_authorize: str, + upstream_token: str, + upstream_jwks_uri: str | None = None, +) -> dict[str, Any]: + """Return upstream metadata, mirroring logic from old server.py.""" + + # If explicit endpoints provided, craft a synthetic metadata object. + if upstream_authorize and upstream_token: + return { + "issuer": upstream_base, + "authorization_endpoint": upstream_authorize, + "token_endpoint": upstream_token, + "registration_endpoint": "/register", + "jwks_uri": upstream_jwks_uri or "", + } + + # Otherwise attempt remote fetch. + metadata_url = f"{upstream_base}/.well-known/oauth-authorization-server" + try: + async with httpx.AsyncClient() as client: + r = await client.get(metadata_url, timeout=10) + r.raise_for_status() + return r.json() + except Exception as exc: # noqa: BLE001 + logger.warning("Could not fetch upstream metadata (%s); using fallback.", exc) + return { + "issuer": "fallback", + "authorization_endpoint": "/authorize", + "token_endpoint": "/token", + "registration_endpoint": "/register", + } + + +# --------------------------------------------------------------------------- +# Route factory – returns Starlette Route objects +# --------------------------------------------------------------------------- + + +def create_proxy_routes(provider: Any) -> list[Route]: # type: ignore[valid-type] + """Create all additional proxy-specific routes. + + The *provider* must be an instance of + `TransparentOAuthProxyProvider` (duck-typed here to avoid circular imports). + """ + + configure_logging() # ensure log format if not already set + + s = provider._s # access its validated settings (_Settings) + + # Introduce a dedicated handler class to avoid nested closures while still + # retaining the convenience of accessing validated settings via + # ``self.s``. This improves introspection, simplifies debugging and makes + # future extensibility (e.g. dependency injection) easier. + + class _ProxyHandlers: # noqa: D401,E501 + """Collection of async endpoints implementing the proxy logic.""" + + def __init__(self, settings: Any): # type: ignore[valid-type] + self.s = settings + + # ------------------------------------------------------------------ + # /.well-known/oauth-authorization-server + # ------------------------------------------------------------------ + async def metadata(self, request: Request) -> Response: # noqa: D401 + logger.info("🔍 /.well-known/oauth-authorization-server endpoint accessed") + + data = await fetch_upstream_metadata( + self.s.upstream_authorize.rsplit("/", 1)[0], # base + str(self.s.upstream_authorize), + str(self.s.upstream_token), + self.s.jwks_uri, + ) + + host = request.headers.get("host", "localhost") + scheme = "https" if request.url.scheme == "https" else "http" + issuer = f"{scheme}://{host}" + data.update( + { + "issuer": issuer, + "authorization_endpoint": f"{issuer}/authorize", + "token_endpoint": f"{issuer}/token", + "registration_endpoint": f"{issuer}/register", + } + ) + return JSONResponse(data) + + # ------------------------------------------------------------------ + # /register – Dynamic Client Registration stub + # ------------------------------------------------------------------ + async def register(self, request: Request) -> Response: # noqa: D401 + logger.info("🔑 /register endpoint accessed - custom proxy implementation") + body = await request.json() + + # Log the incoming request body (redacted) + from mcp.server.fastmcp.utilities.logging import redact_sensitive_data + + redacted_body = redact_sensitive_data(body) + logger.info(f"Incoming registration request: {redacted_body}") + + client_metadata = { + "client_id": self.s.client_id, + "client_secret": self.s.client_secret, + "token_endpoint_auth_method": "client_secret_post" if self.s.client_secret else "none", + **body, + } + + # Log the client ID we're returning + logger.info(f"Returning client_id: {self.s.client_id}") + if self.s.client_secret: + logger.info("Including client_secret in response") + + return JSONResponse(client_metadata, status_code=201) + + # ------------------------------------------------------------------ + # /authorize – Redirect to upstream with injections + # ------------------------------------------------------------------ + async def authorize(self, request: Request) -> Response: # noqa: D401 + params = dict(request.query_params) + params["client_id"] = self.s.client_id + if "scope" not in params: + params["scope"] = self.s.default_scope + + redirect_url = f"{self.s.upstream_authorize}?{urllib.parse.urlencode(params)}" + return RedirectResponse(redirect_url) + + # ------------------------------------------------------------------ + # /revoke – Pass-through + # ------------------------------------------------------------------ + async def revoke(self, request: Request) -> Response: # noqa: D401 + form = await request.form() + data = dict(form) + data.setdefault("client_id", self.s.client_id) + if self.s.client_secret: + data.setdefault("client_secret", self.s.client_secret) + + async with httpx.AsyncClient() as client: + r = await client.post(str(self.s.upstream_token).rsplit("/", 1)[0] + "/revoke", data=data, timeout=10) + return JSONResponse(r.json(), status_code=r.status_code) + + handlers = _ProxyHandlers(s) + + return [ + Route("/.well-known/oauth-authorization-server", handlers.metadata, methods=["GET"]), + Route("/register", handlers.register, methods=["POST"]), + Route("/authorize", handlers.authorize, methods=["GET"]), + Route("/revoke", handlers.revoke, methods=["POST"]), + ] diff --git a/src/mcp/server/auth/proxy/server.py b/src/mcp/server/auth/proxy/server.py new file mode 100644 index 000000000..3b9a4ece5 --- /dev/null +++ b/src/mcp/server/auth/proxy/server.py @@ -0,0 +1,63 @@ +# pyright: reportPrivateUsage=false, reportUnknownParameterType=false +"""Convenience helper for spinning up a FastMCP Transparent OAuth proxy server.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from pydantic import AnyHttpUrl + +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.utilities.logging import configure_logging + +from ..providers.transparent_proxy import TransparentOAuthProxyProvider + +if TYPE_CHECKING: # pragma: no cover – typing-only imports + from mcp.server.auth.providers.transparent_proxy import _Settings as ProxySettings + +__all__ = ["build_proxy_server"] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def build_proxy_server( # noqa: D401,E501 + *, + host: str = "0.0.0.0", + port: int = 8000, + issuer_url: str | None = None, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "DEBUG", + settings: ProxySettings | None = None, +) -> FastMCP: + """Return a fully-configured FastMCP instance running the proxy. + + Prefer passing a fully-validated *settings* object (instance of + :class:`mcp.server.auth.providers.transparent_proxy._Settings`) which makes + configuration explicit and type-checked. + """ + + # Runtime import to avoid circular dependency at module import time. + from ..providers.transparent_proxy import _Settings as ProxySettings + + if settings is None: + settings = ProxySettings.load() + + configure_logging(level=log_level) # type: ignore[arg-type] + + auth_settings = AuthSettings( + issuer_url=AnyHttpUrl(issuer_url or f"http://localhost:{port}"), # type: ignore[arg-type] + resource_server_url=AnyHttpUrl(f"http://localhost:{port}"), # type: ignore[arg-type] + required_scopes=["openid"], + client_registration_options=ClientRegistrationOptions(enabled=True), + ) + + provider = TransparentOAuthProxyProvider(settings=settings, auth_settings=auth_settings) # type: ignore[arg-type] + + mcp = FastMCP( + name="Transparent OAuth Proxy", host=host, port=port, auth_server_provider=provider, auth=auth_settings + ) + + return mcp diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2fe7c1224..f12704129 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -750,17 +750,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # Add auth endpoints if auth server provider is configured if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) + routes.extend(_build_provider_auth_routes(self._auth_server_provider, self.settings.auth)) # When auth is configured, require authentication if self._token_verifier: @@ -863,17 +853,7 @@ def streamable_http_app(self) -> Starlette: # Add auth endpoints if auth server provider is configured if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) + routes.extend(_build_provider_auth_routes(self._auth_server_provider, self.settings.auth)) # Set up routes with or without auth if self._token_verifier: @@ -1162,3 +1142,39 @@ async def warning(self, message: str, **extra: Any) -> None: async def error(self, message: str, **extra: Any) -> None: """Send an error log message.""" await self.log("error", message, **extra) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +# pyright: reportUnknownArgumentType=false, reportUnknownParameterType=false +def _build_provider_auth_routes(provider: OAuthAuthorizationServerProvider[Any, Any, Any], auth_settings: AuthSettings): + """Return the list of Starlette routes for the given provider. + + This consolidates the custom-route fallback logic that previously appeared + twice in ``sse_app`` and ``streamable_http_app``. + """ + + from mcp.server.auth.routes import create_auth_routes # local import to avoid cycles + + # Allow provider to supply its own custom route list (e.g. proxy mode) + get_auth_routes = getattr(provider, "get_auth_routes", None) + if callable(get_auth_routes): + try: + custom = get_auth_routes() + if custom and hasattr(custom, "__iter__"): + return list(custom) # type: ignore[return-value] + except Exception: + # Fall back to default factory on any error + pass + + # Default behaviour – use shared route factory + return create_auth_routes( + provider=provider, + issuer_url=auth_settings.issuer_url, + service_documentation_url=auth_settings.service_documentation_url, + client_registration_options=auth_settings.client_registration_options, + revocation_options=auth_settings.revocation_options, + ) diff --git a/src/mcp/server/fastmcp/utilities/logging.py b/src/mcp/server/fastmcp/utilities/logging.py index 091d57e69..777013b5f 100644 --- a/src/mcp/server/fastmcp/utilities/logging.py +++ b/src/mcp/server/fastmcp/utilities/logging.py @@ -1,7 +1,8 @@ """Logging utilities for FastMCP.""" import logging -from typing import Literal +from collections.abc import Mapping +from typing import Any, Literal def get_logger(name: str) -> logging.Logger: @@ -41,3 +42,52 @@ def configure_logging( format="%(message)s", handlers=handlers, ) + + +# --------------------------------------------------------------------------- +# Helper – redact sensitive data before logging +# --------------------------------------------------------------------------- + + +def redact_sensitive_data( + data: Mapping[str, Any] | None, + sensitive_keys: set[str] | None = None, +) -> Mapping[str, Any] | None: + """Return a shallow copy with sensitive values replaced by "***". + + This shared helper can be used across the code-base (e.g. the transparent + OAuth proxy) to ensure we treat secrets consistently. + + Parameters + ---------- + data: + Original mapping (typically request/response payload). If *None* the + function simply returns *None*. + sensitive_keys: + Optional set of keys that should be hidden; defaults to a common list + of OAuth-related secrets. + """ + + if data is None: + return None + + sensitive_keys = sensitive_keys or { + "client_secret", + "authorization", + "access_token", + "refresh_token", + "code", + } + + redacted: dict[str, Any] = {} + for key, value in data.items(): + if key.lower() in sensitive_keys: + # Show a short prefix of auth codes; redact everything else + if key.lower() == "code" and isinstance(value, str): + redacted[key] = value[:8] + "..." if len(value) > 8 else "***" + else: + redacted[key] = "***" + else: + redacted[key] = value + + return redacted diff --git a/tests/test_proxy_builder.py b/tests/test_proxy_builder.py new file mode 100644 index 000000000..0781274a1 --- /dev/null +++ b/tests/test_proxy_builder.py @@ -0,0 +1,49 @@ +# pyright: reportMissingImports=false, reportGeneralTypeIssues=false +"""Tests for the build_proxy_server convenience helper.""" + +from __future__ import annotations + +from typing import cast + +import httpx # type: ignore +import pytest # type: ignore +from pydantic import AnyHttpUrl + +from mcp.server.auth.providers.transparent_proxy import _Settings as ProxySettings +from mcp.server.auth.proxy import routes as proxy_routes +from mcp.server.auth.proxy.server import build_proxy_server + + +@pytest.mark.anyio +async def test_build_proxy_server_metadata(monkeypatch): + """Ensure the server starts and serves metadata without touching network.""" + + # Patch metadata fetcher so no real HTTP traffic occurs + async def _fake_metadata(): # noqa: D401 + return { + "issuer": "https://proxy.test", + "authorization_endpoint": "https://proxy.test/authorize", + "token_endpoint": "https://proxy.test/token", + "registration_endpoint": "/register", + } + + monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True) + + # Provide required upstream endpoints via settings object + settings = ProxySettings( # type: ignore[call-arg] + UPSTREAM_AUTHORIZATION_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/authorize"), + UPSTREAM_TOKEN_ENDPOINT=cast(AnyHttpUrl, "https://upstream.example.com/token"), + UPSTREAM_CLIENT_ID="demo-client-id", + UPSTREAM_CLIENT_SECRET=None, + UPSTREAM_JWKS_URI=None, + ) + + mcp = build_proxy_server(port=0, settings=settings) + + app = mcp.streamable_http_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as c: + r = await c.get("/.well-known/oauth-authorization-server") + assert r.status_code == 200 + data = r.json() + assert data["authorization_endpoint"].endswith("/authorize") diff --git a/uv.lock b/uv.lock index 7a34275ce..7090d3a4d 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ members = [ "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", "mcp-snippets", + "proxy-auth", ] [[package]] @@ -1168,6 +1169,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "proxy-auth" +version = "0.1.0" +source = { editable = "examples/servers/proxy-auth" } +dependencies = [ + { name = "mcp" }, +] + +[package.optional-dependencies] +dev = [ + { name = "pytest" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "mcp", editable = "." }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=6.0" }, +] +provides-extras = ["dev"] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "ruff", specifier = ">=0.8.5" }, +] + [[package]] name = "pycparser" version = "2.22"