22
22
from mcp .shared .message import ClientMessageMetadata , SessionMessage
23
23
from mcp .types import (
24
24
ErrorData ,
25
+ InitializeResult ,
25
26
JSONRPCError ,
26
27
JSONRPCMessage ,
27
28
JSONRPCNotification ,
39
40
GetSessionIdCallback = Callable [[], str | None ]
40
41
41
42
MCP_SESSION_ID = "mcp-session-id"
43
+ MCP_PROTOCOL_VERSION = "mcp-protocol-version"
42
44
LAST_EVENT_ID = "last-event-id"
43
45
CONTENT_TYPE = "content-type"
44
46
ACCEPT = "Accept"
@@ -97,17 +99,20 @@ def __init__(
97
99
)
98
100
self .auth = auth
99
101
self .session_id = None
102
+ self .protocol_version = None
100
103
self .request_headers = {
101
104
ACCEPT : f"{ JSON } , { SSE } " ,
102
105
CONTENT_TYPE : JSON ,
103
106
** self .headers ,
104
107
}
105
108
106
- def _update_headers_with_session (self , base_headers : dict [str , str ]) -> dict [str , str ]:
107
- """Update headers with session ID if available."""
109
+ def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
110
+ """Update headers with session ID and protocol version if available."""
108
111
headers = base_headers .copy ()
109
112
if self .session_id :
110
113
headers [MCP_SESSION_ID ] = self .session_id
114
+ if self .protocol_version :
115
+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
111
116
return headers
112
117
113
118
def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -128,19 +133,39 @@ def _maybe_extract_session_id_from_response(
128
133
self .session_id = new_session_id
129
134
logger .info (f"Received session ID: { self .session_id } " )
130
135
136
+ def _maybe_extract_protocol_version_from_message (
137
+ self ,
138
+ message : JSONRPCMessage ,
139
+ ) -> None :
140
+ """Extract protocol version from initialization response message."""
141
+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
142
+ try :
143
+ # Parse the result as InitializeResult for type safety
144
+ init_result = InitializeResult .model_validate (message .root .result )
145
+ self .protocol_version = str (init_result .protocolVersion )
146
+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
147
+ except Exception as exc :
148
+ logger .warning (f"Failed to parse initialization response as InitializeResult: { exc } " )
149
+ logger .warning (f"Raw result: { message .root .result } " )
150
+
131
151
async def _handle_sse_event (
132
152
self ,
133
153
sse : ServerSentEvent ,
134
154
read_stream_writer : StreamWriter ,
135
155
original_request_id : RequestId | None = None ,
136
156
resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
157
+ is_initialization : bool = False ,
137
158
) -> bool :
138
159
"""Handle an SSE event, returning True if the response is complete."""
139
160
if sse .event == "message" :
140
161
try :
141
162
message = JSONRPCMessage .model_validate_json (sse .data )
142
163
logger .debug (f"SSE message: { message } " )
143
164
165
+ # Extract protocol version from initialization response
166
+ if is_initialization :
167
+ self ._maybe_extract_protocol_version_from_message (message )
168
+
144
169
# If this is a response and we have original_request_id, replace it
145
170
if original_request_id is not None and isinstance (message .root , JSONRPCResponse | JSONRPCError ):
146
171
message .root .id = original_request_id
@@ -174,7 +199,7 @@ async def handle_get_stream(
174
199
if not self .session_id :
175
200
return
176
201
177
- headers = self ._update_headers_with_session (self .request_headers )
202
+ headers = self ._prepare_request_headers (self .request_headers )
178
203
179
204
async with aconnect_sse (
180
205
client ,
@@ -194,7 +219,7 @@ async def handle_get_stream(
194
219
195
220
async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
196
221
"""Handle a resumption request using GET with SSE."""
197
- headers = self ._update_headers_with_session (ctx .headers )
222
+ headers = self ._prepare_request_headers (ctx .headers )
198
223
if ctx .metadata and ctx .metadata .resumption_token :
199
224
headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
200
225
else :
@@ -227,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
227
252
228
253
async def _handle_post_request (self , ctx : RequestContext ) -> None :
229
254
"""Handle a POST request with response processing."""
230
- headers = self ._update_headers_with_session (ctx .headers )
255
+ headers = self ._prepare_request_headers (ctx .headers )
231
256
message = ctx .session_message .message
232
257
is_initialization = self ._is_initialization_request (message )
233
258
@@ -256,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256
281
content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
257
282
258
283
if content_type .startswith (JSON ):
259
- await self ._handle_json_response (response , ctx .read_stream_writer )
284
+ await self ._handle_json_response (response , ctx .read_stream_writer , is_initialization )
260
285
elif content_type .startswith (SSE ):
261
- await self ._handle_sse_response (response , ctx )
286
+ await self ._handle_sse_response (response , ctx , is_initialization )
262
287
else :
263
288
await self ._handle_unexpected_content_type (
264
289
content_type ,
@@ -269,18 +294,29 @@ async def _handle_json_response(
269
294
self ,
270
295
response : httpx .Response ,
271
296
read_stream_writer : StreamWriter ,
297
+ is_initialization : bool = False ,
272
298
) -> None :
273
299
"""Handle JSON response from the server."""
274
300
try :
275
301
content = await response .aread ()
276
302
message = JSONRPCMessage .model_validate_json (content )
303
+
304
+ # Extract protocol version from initialization response
305
+ if is_initialization :
306
+ self ._maybe_extract_protocol_version_from_message (message )
307
+
277
308
session_message = SessionMessage (message )
278
309
await read_stream_writer .send (session_message )
279
310
except Exception as exc :
280
311
logger .error (f"Error parsing JSON response: { exc } " )
281
312
await read_stream_writer .send (exc )
282
313
283
- async def _handle_sse_response (self , response : httpx .Response , ctx : RequestContext ) -> None :
314
+ async def _handle_sse_response (
315
+ self ,
316
+ response : httpx .Response ,
317
+ ctx : RequestContext ,
318
+ is_initialization : bool = False ,
319
+ ) -> None :
284
320
"""Handle SSE response from the server."""
285
321
try :
286
322
event_source = EventSource (response )
@@ -289,6 +325,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
289
325
sse ,
290
326
ctx .read_stream_writer ,
291
327
resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
328
+ is_initialization = is_initialization ,
292
329
)
293
330
# If the SSE event indicates completion, like returning respose/error
294
331
# break the loop
@@ -385,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
385
422
return
386
423
387
424
try :
388
- headers = self ._update_headers_with_session (self .request_headers )
425
+ headers = self ._prepare_request_headers (self .request_headers )
389
426
response = await client .delete (self .url , headers = headers )
390
427
391
428
if response .status_code == 405 :
0 commit comments