Skip to content

Commit 49c7ed9

Browse files
committed
enables oauth proxy capability
Signed-off-by: Jesse Sanford <[email protected]>
1 parent 6f43d1f commit 49c7ed9

File tree

12 files changed

+1343
-23
lines changed

12 files changed

+1343
-23
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# pyright: reportMissingImports=false
2+
import os
3+
from urllib.parse import urljoin
4+
from dotenv import load_dotenv # type: ignore
5+
from typing import Any, cast
6+
import base64, json, time
7+
from starlette.requests import Request # type: ignore
8+
9+
from mcp.server.fastmcp.server import Context
10+
from mcp.server.auth.proxy.server import build_proxy_server # noqa: E402
11+
12+
# Load environment variables from .env if present
13+
load_dotenv()
14+
15+
# ----------------------------------------------------------------------------
16+
# Environment configuration
17+
# ----------------------------------------------------------------------------
18+
_upstream_base = os.getenv("PROXY_UPSTREAM_BASE", "https://auth.example.com/")
19+
# Sanitize trailing slash
20+
if _upstream_base.endswith("/"):
21+
_upstream_base = _upstream_base[:-1]
22+
UPSTREAM_BASE: str = _upstream_base
23+
24+
print("[Proxy Config] UPSTREAM_BASE:", UPSTREAM_BASE)
25+
print("[Proxy Config] CLIENT_ID:", os.getenv("PROXY_CLIENT_ID") or os.getenv("UPSTREAM_CLIENT_ID"))
26+
27+
CLIENT_ID = os.getenv("PROXY_CLIENT_ID") or os.getenv("UPSTREAM_CLIENT_ID") or "demo-client-id"
28+
CLIENT_SECRET = os.getenv("PROXY_CLIENT_SECRET") or os.getenv("UPSTREAM_CLIENT_SECRET") # may be None
29+
DEFAULT_SCOPE = os.getenv("PROXY_DEFAULT_SCOPE", "openid profile email")
30+
AUDIENCE = os.getenv("PROXY_AUDIENCE") # optional
31+
32+
# ---------------------------------------------------------------------------
33+
# Resolve upstream endpoints – prefer explicit *_ENDPOINT variables (matches
34+
# naming used in fastmcp example) and fall back to BASE + path.
35+
# ---------------------------------------------------------------------------
36+
37+
UPSTREAM_AUTHORIZE = os.getenv("UPSTREAM_AUTHORIZATION_ENDPOINT") or f"{UPSTREAM_BASE}/authorize"
38+
UPSTREAM_TOKEN = os.getenv("UPSTREAM_TOKEN_ENDPOINT") or f"{UPSTREAM_BASE}/token"
39+
UPSTREAM_JWKS_URI = os.getenv("UPSTREAM_JWKS_URI")
40+
UPSTREAM_REVOCATION = os.getenv("UPSTREAM_REVOCATION_ENDPOINT") or f"{UPSTREAM_BASE}/revoke"
41+
42+
# Metadata URL (only used if we need to fetch from upstream)
43+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
44+
45+
print("[Proxy Config] UPSTREAM_AUTHORIZE:", UPSTREAM_AUTHORIZE)
46+
print("[Proxy Config] UPSTREAM_TOKEN:", UPSTREAM_TOKEN)
47+
48+
# Server host/port
49+
PROXY_PORT = int(os.getenv("PROXY_PORT", "8000"))
50+
51+
# ----------------------------------------------------------------------------
52+
# FastMCP server (now created via library helper)
53+
# ----------------------------------------------------------------------------
54+
55+
ISSUER_URL = os.getenv("PROXY_ISSUER_URL", "http://localhost:8000")
56+
57+
# Create FastMCP instance using the reusable proxy builder
58+
mcp = build_proxy_server(port=PROXY_PORT, issuer_url=ISSUER_URL)
59+
60+
# ---------------------------------------------------------------------------
61+
# Minimal demo tool
62+
# ---------------------------------------------------------------------------
63+
64+
@mcp.tool()
65+
def echo(message: str) -> str:
66+
return f"Echo: {message}"
67+
68+
69+
@mcp.tool()
70+
async def user_info(ctx: Context[Any, Any, Request]) -> dict[str, Any]:
71+
"""
72+
Get information about the authenticated user.
73+
74+
This tool demonstrates accessing user information from the OAuth access token.
75+
The user must be authenticated via OAuth to access this tool.
76+
77+
Returns:
78+
Dictionary containing user information from the access token
79+
"""
80+
from mcp.server.auth.middleware.auth_context import get_access_token
81+
82+
# Get the access token from the authentication context
83+
access_token = get_access_token()
84+
85+
if not access_token:
86+
return {
87+
"error": "No access token found - user not authenticated",
88+
"authenticated": False
89+
}
90+
91+
# Attempt to decode the access token as JWT to extract useful user claims.
92+
# Many OAuth providers issue JWT access tokens (or ID tokens) that contain
93+
# the user's subject (sub) and preferred username. We parse the token
94+
# *without* signature verification – we only need the public claims for
95+
# display purposes. If the token is opaque or the decode fails, we simply
96+
# skip this step.
97+
98+
def _try_decode_jwt(token_str: str) -> dict[str, Any] | None: # noqa: D401
99+
"""Best-effort JWT decode without verification.
100+
101+
Returns the payload dictionary if the token *looks* like a JWT and can
102+
be base64-decoded. If anything fails we return None.
103+
"""
104+
105+
try:
106+
parts = token_str.split(".")
107+
if len(parts) != 3:
108+
return None # Not a JWT
109+
110+
# JWT parts are URL-safe base64 without padding
111+
def _b64decode(segment: str) -> bytes:
112+
padding = "=" * (-len(segment) % 4)
113+
return base64.urlsafe_b64decode(segment + padding)
114+
115+
payload_bytes = _b64decode(parts[1])
116+
return json.loads(payload_bytes)
117+
except Exception: # noqa: BLE001
118+
return None
119+
120+
jwt_claims = _try_decode_jwt(access_token.token)
121+
122+
# Build response with token information plus any extracted claims
123+
response: dict[str, Any] = {
124+
"authenticated": True,
125+
"client_id": access_token.client_id,
126+
"scopes": access_token.scopes,
127+
"token_type": "Bearer",
128+
"expires_at": access_token.expires_at,
129+
"resource": access_token.resource,
130+
}
131+
132+
if jwt_claims:
133+
# Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
134+
uid = jwt_claims.get("userid") or jwt_claims.get("sub")
135+
if uid is not None:
136+
response["userid"] = uid # camelCase variant used in FastMCP reference
137+
response["user_id"] = uid # snake_case variant
138+
response["username"] = (
139+
jwt_claims.get("preferred_username")
140+
or jwt_claims.get("nickname")
141+
or jwt_claims.get("name")
142+
)
143+
response["issuer"] = jwt_claims.get("iss")
144+
response["audience"] = jwt_claims.get("aud")
145+
response["issued_at"] = jwt_claims.get("iat")
146+
147+
# Calculate expiration helpers
148+
if access_token.expires_at:
149+
response["expires_at_iso"] = time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime(access_token.expires_at))
150+
response["expires_in_seconds"] = max(0, access_token.expires_at - int(time.time()))
151+
152+
return response
153+
154+
155+
@mcp.tool()
156+
async def test_endpoint(message: str = "Hello from proxy server!") -> dict[str, Any]:
157+
"""
158+
Test endpoint for debugging OAuth proxy functionality.
159+
160+
Args:
161+
message: Optional message to echo back
162+
163+
Returns:
164+
Test response with server information
165+
"""
166+
return {
167+
"message": message,
168+
"server": "Transparent OAuth Proxy Server",
169+
"status": "active",
170+
"oauth_configured": True
171+
}
172+
173+
174+
if __name__ == "__main__":
175+
mcp.run(transport="streamable-http")
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script to demonstrate the enhanced logging features of the transparent OAuth proxy.
4+
This script makes requests to various endpoints to show the comprehensive logging.
5+
"""
6+
7+
import time
8+
import requests
9+
import json
10+
import sys
11+
from threading import Thread
12+
import subprocess
13+
import os
14+
import signal
15+
16+
def test_endpoints():
17+
"""Test various endpoints to demonstrate logging."""
18+
base_url = "http://localhost:8000"
19+
20+
print("\n" + "="*60)
21+
print("🧪 TESTING ENHANCED LOGGING FEATURES")
22+
print("="*60)
23+
24+
# Wait for server to be ready
25+
print("⏳ Waiting for server to be ready...")
26+
time.sleep(3)
27+
28+
try:
29+
# Test 1: Metadata discovery
30+
print("\n🔍 Testing OAuth metadata discovery...")
31+
response = requests.get(f"{base_url}/.well-known/oauth-authorization-server",
32+
headers={"Host": "localhost:8000"})
33+
print(f" Status: {response.status_code}")
34+
35+
# Test 2: Client registration (DCR)
36+
print("\n📝 Testing Dynamic Client Registration...")
37+
registration_data = {
38+
"redirect_uris": ["http://localhost:3000/callback"],
39+
"grant_types": ["authorization_code"],
40+
"response_types": ["code"],
41+
"client_name": "Test MCP Client"
42+
}
43+
response = requests.post(f"{base_url}/register",
44+
json=registration_data,
45+
headers={"Content-Type": "application/json"})
46+
print(f" Status: {response.status_code}")
47+
48+
# Test 3: Authorization endpoint
49+
print("\n🔐 Testing authorization endpoint...")
50+
auth_params = {
51+
"response_type": "code",
52+
"client_id": "test-client-id",
53+
"redirect_uri": "http://localhost:3000/callback",
54+
"scope": "openid profile email",
55+
"state": "test-state-12345",
56+
"code_challenge": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
57+
"code_challenge_method": "S256"
58+
}
59+
response = requests.get(f"{base_url}/authorize",
60+
params=auth_params,
61+
allow_redirects=False)
62+
print(f" Status: {response.status_code}")
63+
64+
# Test 4: Token endpoint (this will show the most comprehensive logging)
65+
print("\n🎫 Testing token endpoint (will fail but show logging)...")
66+
token_data = {
67+
"grant_type": "authorization_code",
68+
"client_id": "test-client-id",
69+
"client_secret": "test-client-secret",
70+
"code": "test_auth_code_abcdef123456",
71+
"redirect_uri": "http://localhost:3000/callback",
72+
"code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
73+
}
74+
response = requests.post(f"{base_url}/token",
75+
data=token_data,
76+
headers={"Content-Type": "application/x-www-form-urlencoded"})
77+
print(f" Status: {response.status_code}")
78+
79+
# Test 5: Another token request with different parameters
80+
print("\n🎫 Testing token endpoint with different parameters...")
81+
token_data2 = {
82+
"grant_type": "authorization_code",
83+
"client_id": "test-client-id",
84+
"code": "different_test_code_xyz789",
85+
"redirect_uri": "http://localhost:3000/callback"
86+
}
87+
response = requests.post(f"{base_url}/token",
88+
data=token_data2,
89+
headers={"Content-Type": "application/x-www-form-urlencoded"})
90+
print(f" Status: {response.status_code}")
91+
92+
print("\n✅ Test requests completed!")
93+
print("🔍 Check the server logs above to see the enhanced logging with:")
94+
print(" • Correlation IDs for request tracing")
95+
print(" • Detailed request/response headers and data")
96+
print(" • Timing information")
97+
print(" • Emoji indicators for easy scanning")
98+
print(" • Sensitive data redaction")
99+
100+
except requests.exceptions.ConnectionError:
101+
print("❌ Could not connect to server. Make sure it's running on port 8000.")
102+
except Exception as e:
103+
print(f"❌ Error during testing: {e}")
104+
105+
if __name__ == "__main__":
106+
test_endpoints()

src/mcp/server/auth/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,45 @@
1+
# pyright: reportGeneralTypeIssues=false
12
"""
23
MCP OAuth server authorization components.
34
"""
5+
6+
# Convenience re-exports so users can simply::
7+
#
8+
# from mcp.server.auth import build_proxy_server
9+
#
10+
# instead of digging into the sub-package path.
11+
12+
from typing import TYPE_CHECKING
13+
14+
from mcp.server.auth.proxy import (
15+
configure_colored_logging,
16+
create_proxy_routes,
17+
fetch_upstream_metadata,
18+
)
19+
20+
# For *build_proxy_server* we need a lazy import to avoid a circular reference
21+
# during the initial package import sequence (FastMCP -> auth -> proxy ->
22+
# FastMCP ...). PEP 562 allows us to implement module-level `__getattr__` for
23+
# this purpose.
24+
25+
def __getattr__(name: str): # noqa: D401
26+
if name == "build_proxy_server":
27+
from mcp.server.auth.proxy.server import build_proxy_server as _bps # noqa: WPS433
28+
29+
globals()["build_proxy_server"] = _bps
30+
return _bps
31+
raise AttributeError(name)
32+
33+
# ---------------------------------------------------------------------------
34+
# Public API specification
35+
# ---------------------------------------------------------------------------
36+
37+
__all__: list[str] = [
38+
"configure_colored_logging",
39+
"create_proxy_routes",
40+
"fetch_upstream_metadata",
41+
"build_proxy_server",
42+
]
43+
44+
if TYPE_CHECKING: # pragma: no cover – make *build_proxy_server* visible to type checkers
45+
from mcp.server.auth.proxy.server import build_proxy_server # noqa: F401

0 commit comments

Comments
 (0)