Skip to content

Commit c491f37

Browse files
committed
Support falling back to OIDC metadata for auth
1 parent b8cb367 commit c491f37

File tree

2 files changed

+252
-60
lines changed

2 files changed

+252
-60
lines changed

src/mcp/client/auth.py

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
175175
return protocol_version >= "2025-06-18"
176176

177177

178+
OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]]
179+
180+
178181
class OAuthClientProvider(httpx.Auth):
179182
"""
180183
OAuth2 authentication for httpx.
@@ -221,32 +224,60 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
221224
except ValidationError:
222225
pass
223226

224-
def _build_well_known_path(self, pathname: str) -> str:
227+
def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str:
225228
"""Construct well-known path for OAuth metadata discovery."""
226-
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
229+
well_known_path = f"/.well-known/{well_known_endpoint}{pathname}"
227230
if pathname.endswith("/"):
228231
# Strip trailing slash from pathname to avoid double slashes
229232
well_known_path = well_known_path[:-1]
230233
return well_known_path
231234

232-
def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
233-
"""Determine if fallback to root discovery should be attempted."""
234-
return response_status == 404 and pathname != "/"
235+
def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str:
236+
"""Construct fallback well-known URL for OAuth metadata discovery in legacy servers."""
237+
base_url = getattr(self.context, "discovery_base_url", "")
238+
if not base_url:
239+
raise OAuthFlowError("No base URL available for fallback discovery")
240+
241+
# Fallback to root discovery for legacy servers
242+
return urljoin(base_url, f"/.well-known/{well_known_endpoint}")
243+
244+
def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str:
245+
"""Construct fallback well-known path for OIDC metadata discovery in legacy servers."""
246+
# Strip trailing slash from pathname to avoid double slashes
247+
clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname
248+
# OIDC 1.0 appends the well-known path to the full AS URL
249+
return f"{clean_pathname}/.well-known/{well_known_endpoint}"
250+
251+
def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str:
252+
"""Construct fallback well-known URL for OIDC metadata discovery in legacy servers."""
253+
if self.context.auth_server_url:
254+
auth_server_url = self.context.auth_server_url
255+
else:
256+
auth_server_url = self.context.server_url
257+
258+
parsed = urlparse(auth_server_url)
259+
well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint)
260+
base_url = f"{parsed.scheme}://{parsed.netloc}"
261+
return urljoin(base_url, well_known_path)
262+
263+
def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool:
264+
"""Determine if further fallback should be attempted."""
265+
return response_status == 404 and len(discovery_stack) > 0
235266

236267
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
237268
"""Build metadata discovery request for a specific URL."""
238269
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
239270

240-
async def _discover_oauth_metadata(self) -> httpx.Request:
241-
"""Build OAuth metadata discovery request with fallback support."""
271+
async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request:
272+
"""Build .well-known metadata discovery request with fallback support."""
242273
if self.context.auth_server_url:
243274
auth_server_url = self.context.auth_server_url
244275
else:
245276
auth_server_url = self.context.server_url
246277

247278
# Per RFC 8414, try path-aware discovery first
248279
parsed = urlparse(auth_server_url)
249-
well_known_path = self._build_well_known_path(parsed.path)
280+
well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint)
250281
base_url = f"{parsed.scheme}://{parsed.netloc}"
251282
url = urljoin(base_url, well_known_path)
252283

@@ -256,17 +287,37 @@ async def _discover_oauth_metadata(self) -> httpx.Request:
256287

257288
return await self._try_metadata_discovery(url)
258289

290+
async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request:
291+
"""Build fallback OAuth metadata discovery request for legacy servers."""
292+
url = self._build_well_known_fallback_url(well_known_endpoint)
293+
return await self._try_metadata_discovery(url)
294+
295+
async def _discover_oauth_metadata(self) -> httpx.Request:
296+
"""Build OAuth metadata discovery request with fallback support."""
297+
return await self._discover_well_known_metadata("oauth-authorization-server")
298+
259299
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
260300
"""Build fallback OAuth metadata discovery request for legacy servers."""
261-
base_url = getattr(self.context, "discovery_base_url", "")
262-
if not base_url:
263-
raise OAuthFlowError("No base URL available for fallback discovery")
301+
return await self._discover_well_known_metadata_fallback("oauth-authorization-server")
264302

265-
# Fallback to root discovery for legacy servers
266-
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
303+
async def _discover_oidc_metadata(self) -> httpx.Request:
304+
"""
305+
Build fallback OIDC metadata discovery request.
306+
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
307+
"""
308+
return await self._discover_well_known_metadata("openid-configuration")
309+
310+
async def _discover_oidc_metadata_fallback(self) -> httpx.Request:
311+
"""
312+
Build fallback OIDC metadata discovery request for legacy servers.
313+
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
314+
"""
315+
url = self._build_oidc_fallback_url("openid-configuration")
267316
return await self._try_metadata_discovery(url)
268317

269-
async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
318+
async def _handle_oauth_metadata_response(
319+
self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack
320+
) -> bool:
270321
"""Handle OAuth metadata response. Returns True if handled successfully."""
271322
if response.status_code == 200:
272323
try:
@@ -280,13 +331,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal
280331
except ValidationError:
281332
pass
282333

283-
# Check if we should attempt fallback (404 on path-aware discovery)
284-
if not is_fallback and self._should_attempt_fallback(
285-
response.status_code, getattr(self.context, "discovery_pathname", "/")
286-
):
287-
return False # Signal that fallback should be attempted
288-
289-
return True # Signal no fallback needed (either success or non-404 error)
334+
# Check if we should attempt fallback
335+
# True: No fallback needed (either success or non-404 error)
336+
# False: Signal that fallback should be attempted
337+
return not self._should_attempt_fallback(response.status_code, discovery_stack)
290338

291339
async def _register_client(self) -> httpx.Request | None:
292340
"""Build registration request or skip if already registered."""
@@ -481,6 +529,26 @@ def _add_auth_header(self, request: httpx.Request) -> None:
481529
if self.context.current_tokens and self.context.current_tokens.access_token:
482530
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
483531

532+
def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack:
533+
"""Create a stack of attempts to discover OAuth metadata."""
534+
discovery_attempts: OAuthDiscoveryStack = [
535+
# Start with path-aware OAuth discovery
536+
self._discover_oauth_metadata,
537+
# If path-aware discovery fails with 404, try fallback to root
538+
self._discover_oauth_metadata_fallback,
539+
# If root discovery fails with 404, fall back to OIDC 1.0 following
540+
# RFC 8414 path-aware semantics (see RFC 8414 section 5)
541+
self._discover_oidc_metadata,
542+
# If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0
543+
# following OIDC 1.0 semantics (see RFC 8414 section 5)
544+
self._discover_oidc_metadata_fallback,
545+
]
546+
547+
# Reverse the list so we can call pop() without remembering we declared
548+
# this stack backwards for readability
549+
discovery_attempts.reverse()
550+
return discovery_attempts
551+
484552
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
485553
"""HTTPX auth flow integration."""
486554
async with self.context.lock:
@@ -500,15 +568,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
500568
await self._handle_protected_resource_response(discovery_response)
501569

502570
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
503-
oauth_request = await self._discover_oauth_metadata()
504-
oauth_response = yield oauth_request
505-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
506-
507-
# If path-aware discovery failed with 404, try fallback to root
508-
if not handled:
509-
fallback_request = await self._discover_oauth_metadata_fallback()
510-
fallback_response = yield fallback_request
511-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
571+
oauth_discovery_stack = self._create_oauth_discovery_stack()
572+
while len(oauth_discovery_stack) > 0:
573+
oauth_discovery = oauth_discovery_stack.pop()
574+
oauth_request = await oauth_discovery()
575+
oauth_response = yield oauth_request
576+
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)
512577

513578
# Step 3: Register client if needed
514579
registration_request = await self._register_client()
@@ -552,15 +617,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
552617
await self._handle_protected_resource_response(discovery_response)
553618

554619
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
555-
oauth_request = await self._discover_oauth_metadata()
556-
oauth_response = yield oauth_request
557-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
558-
559-
# If path-aware discovery failed with 404, try fallback to root
560-
if not handled:
561-
fallback_request = await self._discover_oauth_metadata_fallback()
562-
fallback_response = yield fallback_request
563-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
620+
oauth_discovery_stack = self._create_oauth_discovery_stack()
621+
while len(oauth_discovery_stack) > 0:
622+
oauth_discovery = oauth_discovery_stack.pop()
623+
oauth_request = await oauth_discovery()
624+
oauth_response = yield oauth_request
625+
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)
564626

565627
# Step 3: Register client if needed
566628
registration_request = await self._register_client()

0 commit comments

Comments
 (0)