Skip to content

Commit 0d8595c

Browse files
authored
Merge branch 'main' into fix-test-with-snapshot
2 parents ec980dd + df15e95 commit 0d8595c

File tree

6 files changed

+268
-21
lines changed

6 files changed

+268
-21
lines changed

src/mcp/client/auth.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
import anyio
1818
import httpx
1919

20-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
20+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
21+
from mcp.shared.auth import (
22+
OAuthClientInformationFull,
23+
OAuthClientMetadata,
24+
OAuthMetadata,
25+
OAuthToken,
26+
)
2127
from mcp.types import LATEST_PROTOCOL_VERSION
2228

2329
logger = logging.getLogger(__name__)
@@ -121,7 +127,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
121127
# Extract base URL per MCP spec
122128
auth_base_url = self._get_authorization_base_url(server_url)
123129
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
124-
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
130+
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
125131

126132
async with httpx.AsyncClient() as client:
127133
try:

src/mcp/client/streamable_http.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2323
from mcp.types import (
2424
ErrorData,
25+
InitializeResult,
2526
JSONRPCError,
2627
JSONRPCMessage,
2728
JSONRPCNotification,
@@ -39,6 +40,7 @@
3940
GetSessionIdCallback = Callable[[], str | None]
4041

4142
MCP_SESSION_ID = "mcp-session-id"
43+
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
4244
LAST_EVENT_ID = "last-event-id"
4345
CONTENT_TYPE = "content-type"
4446
ACCEPT = "Accept"
@@ -97,17 +99,20 @@ def __init__(
9799
)
98100
self.auth = auth
99101
self.session_id = None
102+
self.protocol_version = None
100103
self.request_headers = {
101104
ACCEPT: f"{JSON}, {SSE}",
102105
CONTENT_TYPE: JSON,
103106
**self.headers,
104107
}
105108

106-
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
107-
"""Update headers with session ID if available."""
109+
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
110+
"""Update headers with session ID and protocol version if available."""
108111
headers = base_headers.copy()
109112
if self.session_id:
110113
headers[MCP_SESSION_ID] = self.session_id
114+
if self.protocol_version:
115+
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
111116
return headers
112117

113118
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -128,19 +133,39 @@ def _maybe_extract_session_id_from_response(
128133
self.session_id = new_session_id
129134
logger.info(f"Received session ID: {self.session_id}")
130135

136+
def _maybe_extract_protocol_version_from_message(
137+
self,
138+
message: JSONRPCMessage,
139+
) -> None:
140+
"""Extract protocol version from initialization response message."""
141+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
142+
try:
143+
# Parse the result as InitializeResult for type safety
144+
init_result = InitializeResult.model_validate(message.root.result)
145+
self.protocol_version = str(init_result.protocolVersion)
146+
logger.info(f"Negotiated protocol version: {self.protocol_version}")
147+
except Exception as exc:
148+
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
149+
logger.warning(f"Raw result: {message.root.result}")
150+
131151
async def _handle_sse_event(
132152
self,
133153
sse: ServerSentEvent,
134154
read_stream_writer: StreamWriter,
135155
original_request_id: RequestId | None = None,
136156
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
157+
is_initialization: bool = False,
137158
) -> bool:
138159
"""Handle an SSE event, returning True if the response is complete."""
139160
if sse.event == "message":
140161
try:
141162
message = JSONRPCMessage.model_validate_json(sse.data)
142163
logger.debug(f"SSE message: {message}")
143164

165+
# Extract protocol version from initialization response
166+
if is_initialization:
167+
self._maybe_extract_protocol_version_from_message(message)
168+
144169
# If this is a response and we have original_request_id, replace it
145170
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
146171
message.root.id = original_request_id
@@ -174,7 +199,7 @@ async def handle_get_stream(
174199
if not self.session_id:
175200
return
176201

177-
headers = self._update_headers_with_session(self.request_headers)
202+
headers = self._prepare_request_headers(self.request_headers)
178203

179204
async with aconnect_sse(
180205
client,
@@ -194,7 +219,7 @@ async def handle_get_stream(
194219

195220
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
196221
"""Handle a resumption request using GET with SSE."""
197-
headers = self._update_headers_with_session(ctx.headers)
222+
headers = self._prepare_request_headers(ctx.headers)
198223
if ctx.metadata and ctx.metadata.resumption_token:
199224
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
200225
else:
@@ -227,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
227252

228253
async def _handle_post_request(self, ctx: RequestContext) -> None:
229254
"""Handle a POST request with response processing."""
230-
headers = self._update_headers_with_session(ctx.headers)
255+
headers = self._prepare_request_headers(ctx.headers)
231256
message = ctx.session_message.message
232257
is_initialization = self._is_initialization_request(message)
233258

@@ -256,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256281
content_type = response.headers.get(CONTENT_TYPE, "").lower()
257282

258283
if content_type.startswith(JSON):
259-
await self._handle_json_response(response, ctx.read_stream_writer)
284+
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
260285
elif content_type.startswith(SSE):
261-
await self._handle_sse_response(response, ctx)
286+
await self._handle_sse_response(response, ctx, is_initialization)
262287
else:
263288
await self._handle_unexpected_content_type(
264289
content_type,
@@ -269,18 +294,29 @@ async def _handle_json_response(
269294
self,
270295
response: httpx.Response,
271296
read_stream_writer: StreamWriter,
297+
is_initialization: bool = False,
272298
) -> None:
273299
"""Handle JSON response from the server."""
274300
try:
275301
content = await response.aread()
276302
message = JSONRPCMessage.model_validate_json(content)
303+
304+
# Extract protocol version from initialization response
305+
if is_initialization:
306+
self._maybe_extract_protocol_version_from_message(message)
307+
277308
session_message = SessionMessage(message)
278309
await read_stream_writer.send(session_message)
279310
except Exception as exc:
280311
logger.error(f"Error parsing JSON response: {exc}")
281312
await read_stream_writer.send(exc)
282313

283-
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
314+
async def _handle_sse_response(
315+
self,
316+
response: httpx.Response,
317+
ctx: RequestContext,
318+
is_initialization: bool = False,
319+
) -> None:
284320
"""Handle SSE response from the server."""
285321
try:
286322
event_source = EventSource(response)
@@ -289,6 +325,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
289325
sse,
290326
ctx.read_stream_writer,
291327
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
328+
is_initialization=is_initialization,
292329
)
293330
# If the SSE event indicates completion, like returning respose/error
294331
# break the loop
@@ -385,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
385422
return
386423

387424
try:
388-
headers = self._update_headers_with_session(self.request_headers)
425+
headers = self._prepare_request_headers(self.request_headers)
389426
response = await client.delete(self.url, headers=headers)
390427

391428
if response.status_code == 405:

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
1717
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
1818
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
19+
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
1920
from mcp.shared.auth import OAuthMetadata
2021

2122

@@ -55,7 +56,7 @@ def cors_middleware(
5556
app=request_response(handler),
5657
allow_origins="*",
5758
allow_methods=allow_methods,
58-
allow_headers=["mcp-protocol-version"],
59+
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
5960
)
6061
return cors_app
6162

src/mcp/server/streamable_http.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from starlette.types import Receive, Scope, Send
2626

2727
from mcp.shared.message import ServerMessageMetadata, SessionMessage
28+
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2829
from mcp.types import (
30+
DEFAULT_NEGOTIATED_VERSION,
2931
INTERNAL_ERROR,
3032
INVALID_PARAMS,
3133
INVALID_REQUEST,
@@ -45,6 +47,7 @@
4547

4648
# Header names
4749
MCP_SESSION_ID_HEADER = "mcp-session-id"
50+
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
4851
LAST_EVENT_ID_HEADER = "last-event-id"
4952

5053
# Content types
@@ -293,7 +296,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
293296
has_json, has_sse = self._check_accept_headers(request)
294297
if not (has_json and has_sse):
295298
response = self._create_error_response(
296-
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
299+
("Not Acceptable: Client must accept both application/json and text/event-stream"),
297300
HTTPStatus.NOT_ACCEPTABLE,
298301
)
299302
await response(scope, receive, send)
@@ -353,8 +356,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
353356
)
354357
await response(scope, receive, send)
355358
return
356-
# For non-initialization requests, validate the session
357-
elif not await self._validate_session(request, send):
359+
elif not await self._validate_request_headers(request, send):
358360
return
359361

360362
# For notifications and responses only, return 202 Accepted
@@ -513,8 +515,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
513515
await response(request.scope, request.receive, send)
514516
return
515517

516-
if not await self._validate_session(request, send):
518+
if not await self._validate_request_headers(request, send):
517519
return
520+
518521
# Handle resumability: check for Last-Event-ID header
519522
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
520523
await self._replay_events(last_event_id, request, send)
@@ -593,7 +596,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
593596
await response(request.scope, request.receive, send)
594597
return
595598

596-
if not await self._validate_session(request, send):
599+
if not await self._validate_request_headers(request, send):
597600
return
598601

599602
await self._terminate_session()
@@ -653,6 +656,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
653656
)
654657
await response(request.scope, request.receive, send)
655658

659+
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
660+
if not await self._validate_session(request, send):
661+
return False
662+
if not await self._validate_protocol_version(request, send):
663+
return False
664+
return True
665+
656666
async def _validate_session(self, request: Request, send: Send) -> bool:
657667
"""Validate the session ID in the request."""
658668
if not self.mcp_session_id:
@@ -682,6 +692,28 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
682692

683693
return True
684694

695+
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
696+
"""Validate the protocol version header in the request."""
697+
# Get the protocol version from the request headers
698+
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
699+
700+
# If no protocol version provided, assume default version
701+
if protocol_version is None:
702+
protocol_version = DEFAULT_NEGOTIATED_VERSION
703+
704+
# Check if the protocol version is supported
705+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
706+
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
707+
response = self._create_error_response(
708+
f"Bad Request: Unsupported protocol version: {protocol_version}. "
709+
+ f"Supported versions: {supported_versions}",
710+
HTTPStatus.BAD_REQUEST,
711+
)
712+
await response(request.scope, request.receive, send)
713+
return False
714+
715+
return True
716+
685717
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
686718
"""
687719
Replays events that would have been sent after the specified event ID.

src/mcp/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@
2424

2525
LATEST_PROTOCOL_VERSION = "2025-03-26"
2626

27+
"""
28+
The default negotiated version of the Model Context Protocol when no version is specified.
29+
We need this to satisfy the MCP specification, which requires the server to assume a
30+
specific version if none is provided by the client. See section "Protocol Version Header" at
31+
https://modelcontextprotocol.io/specification
32+
"""
33+
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
34+
2735
ProgressToken = str | int
2836
Cursor = str
2937
Role = Literal["user", "assistant"]

0 commit comments

Comments
 (0)