Skip to content

Commit 90436c1

Browse files
committed
split resource server and auth server
Signed-off-by: Jesse Sanford <[email protected]>
1 parent e592c56 commit 90436c1

File tree

3 files changed

+565
-0
lines changed

3 files changed

+565
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# pyright: reportMissingImports=false
2+
import logging
3+
import os
4+
import time
5+
6+
from dotenv import load_dotenv # type: ignore
7+
from mcp.server.auth.provider import AccessToken, OAuthToken
8+
from mcp.server.auth.providers.transparent_proxy import (
9+
ProxySettings, # type: ignore
10+
TransparentOAuthProxyProvider,
11+
ProxyTokenHandler,
12+
)
13+
from mcp.server.auth.routes import cors_middleware, create_auth_routes
14+
from mcp.server.auth.settings import ClientRegistrationOptions
15+
from pydantic import AnyHttpUrl
16+
from starlette.applications import Starlette
17+
from starlette.requests import Request # type: ignore
18+
from starlette.responses import JSONResponse, Response
19+
from starlette.routing import Route
20+
from uvicorn import Config, Server
21+
22+
# Load environment variables from .env if present
23+
load_dotenv()
24+
25+
# Configure logging after .env so LOG_LEVEL can come from environment
26+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
27+
28+
logging.basicConfig(
29+
level=LOG_LEVEL,
30+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
31+
datefmt="%Y-%m-%d %H:%M:%S",
32+
)
33+
34+
# Dedicated logger for this server module
35+
logger = logging.getLogger("proxy_oauth.auth_server")
36+
37+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
38+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
39+
# ListToolsRequest") are helpful for debugging but clutter normal output.
40+
41+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
42+
if LOG_LEVEL == "DEBUG":
43+
# In full debug mode, allow the library to emit its detailed logs
44+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
45+
else:
46+
# Otherwise, only warnings and above
47+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
48+
49+
# ----------------------------------------------------------------------------
50+
# Environment configuration
51+
# ----------------------------------------------------------------------------
52+
# Load and validate settings from the environment (uses .env automatically)
53+
settings = ProxySettings.load()
54+
55+
# Upstream endpoints (fully-qualified URLs)
56+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
57+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
58+
UPSTREAM_JWKS_URI = settings.jwks_uri
59+
# Derive base URL from the authorize endpoint for convenience / tests
60+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
61+
62+
# Client credentials & defaults
63+
CLIENT_ID: str = settings.client_id or "demo-client-id"
64+
CLIENT_SECRET = settings.client_secret
65+
DEFAULT_SCOPE: str = settings.default_scope
66+
67+
# Optional audience passthrough (not part of ProxySettings yet)
68+
AUDIENCE = os.getenv("PROXY_AUDIENCE")
69+
70+
# Metadata URL (only used if we need to fetch from upstream)
71+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
72+
73+
# ---------------------------------------------------------------------------
74+
# Logging helpers
75+
# ---------------------------------------------------------------------------
76+
77+
78+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
79+
"""Return a masked version of the given secret.
80+
81+
The first and last four characters are preserved (if available) and the
82+
middle section is replaced by asterisks. If the secret is shorter than
83+
eight characters, the entire value is replaced by ``*``.
84+
"""
85+
86+
if not secret:
87+
return None
88+
89+
if len(secret) <= 8:
90+
return "*" * len(secret)
91+
92+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
93+
94+
95+
# Consolidated configuration (with sensitive data redacted)
96+
_masked_settings = settings.model_dump(exclude_none=True).copy()
97+
98+
if "client_secret" in _masked_settings:
99+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
100+
101+
# Log configuration at *debug* level only so it can be enabled when needed
102+
logger.debug("[Auth Proxy Config] %s", _masked_settings)
103+
104+
# Server host/port
105+
AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000"))
106+
AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost")
107+
AUTH_SERVER_URL = os.getenv(
108+
"AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}"
109+
)
110+
111+
# ----------------------------------------------------------------------------
112+
# Auth Server
113+
# ----------------------------------------------------------------------------
114+
115+
# Create auth provider
116+
oauth_provider = TransparentOAuthProxyProvider(settings=settings)
117+
118+
# Enable client registration
119+
client_registration_options = ClientRegistrationOptions(
120+
enabled=True,
121+
valid_scopes=["openid"],
122+
default_scopes=["openid"],
123+
)
124+
125+
# Create auth routes
126+
routes = create_auth_routes(
127+
provider=oauth_provider,
128+
issuer_url=AnyHttpUrl(AUTH_SERVER_URL),
129+
service_documentation_url=None,
130+
client_registration_options=client_registration_options,
131+
revocation_options=None,
132+
)
133+
134+
# Add token endpoint handler
135+
# We need to replace any existing token endpoint route
136+
routes = [r for r in routes if not (hasattr(r, "path") and r.path == "/token")]
137+
138+
# Create token handler and add it to routes
139+
proxy_token_handler = ProxyTokenHandler(oauth_provider)
140+
routes.append(Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"]))
141+
142+
# Add token introspection endpoint for Resource Servers
143+
async def introspect_handler(request: Request) -> Response:
144+
"""
145+
Token introspection endpoint for Resource Servers.
146+
147+
Resource Servers call this endpoint to validate tokens without
148+
needing direct access to token storage.
149+
"""
150+
form = await request.form()
151+
token = form.get("token")
152+
if not token or not isinstance(token, str):
153+
return JSONResponse({"active": False}, status_code=400)
154+
155+
# For the transparent proxy, we don't actually validate tokens
156+
# Just create a dummy AccessToken like the provider does
157+
access_token = AccessToken(
158+
token=token, client_id=str(CLIENT_ID), scopes=[DEFAULT_SCOPE], expires_at=None
159+
)
160+
161+
return JSONResponse(
162+
{
163+
"active": True,
164+
"client_id": access_token.client_id,
165+
"scope": " ".join(access_token.scopes),
166+
"exp": access_token.expires_at,
167+
"iat": int(time.time()),
168+
"token_type": "Bearer",
169+
"aud": access_token.resource, # RFC 8707 audience claim
170+
}
171+
)
172+
173+
174+
routes.append(
175+
Route(
176+
"/introspect",
177+
endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
178+
methods=["POST", "OPTIONS"],
179+
)
180+
)
181+
182+
# Create Starlette app with routes
183+
auth_app = Starlette(routes=routes)
184+
185+
186+
async def run_server():
187+
"""Run the Authorization Server."""
188+
config = Config(
189+
auth_app,
190+
host=AUTH_SERVER_HOST,
191+
port=AUTH_SERVER_PORT,
192+
log_level="info",
193+
)
194+
server = Server(config)
195+
196+
logger.info(f"🚀 MCP Authorization Server running on {AUTH_SERVER_URL}")
197+
198+
await server.serve()
199+
200+
201+
if __name__ == "__main__":
202+
import asyncio
203+
204+
asyncio.run(run_server())

0 commit comments

Comments
 (0)