2020from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
2121from typing_extensions import deprecated
2222
23- from mcp .shared ._httpx_utils import (
24- MCP_DEFAULT_SSE_READ_TIMEOUT ,
25- MCP_DEFAULT_TIMEOUT ,
26- McpHttpClientFactory ,
27- create_mcp_http_client ,
28- )
23+ from mcp .shared ._httpx_utils import McpHttpClientFactory , create_mcp_http_client
2924from mcp .shared .message import ClientMessageMetadata , SessionMessage
3025from mcp .types import (
3126 ErrorData ,
@@ -70,52 +65,32 @@ class RequestContext:
7065 """Context for a request operation."""
7166
7267 client : httpx .AsyncClient
73- headers : dict [str , str ]
7468 session_id : str | None
7569 session_message : SessionMessage
7670 metadata : ClientMessageMetadata | None
7771 read_stream_writer : StreamWriter
78- sse_read_timeout : float
7972
8073
8174class StreamableHTTPTransport :
8275 """StreamableHTTP client transport implementation."""
8376
84- def __init__ (
85- self ,
86- url : str ,
87- headers : dict [str , str ] | None = None ,
88- timeout : float | timedelta = 30 ,
89- sse_read_timeout : float | timedelta = 60 * 5 ,
90- auth : httpx .Auth | None = None ,
91- ) -> None :
77+ def __init__ (self , url : str ) -> None :
9278 """Initialize the StreamableHTTP transport.
9379
9480 Args:
9581 url: The endpoint URL.
96- headers: Optional headers to include in requests.
97- timeout: HTTP timeout for regular operations.
98- sse_read_timeout: Timeout for SSE read operations.
99- auth: Optional HTTPX authentication handler.
10082 """
10183 self .url = url
102- self .headers = headers or {}
103- self .timeout = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
104- self .sse_read_timeout = (
105- sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
106- )
107- self .auth = auth
10884 self .session_id = None
10985 self .protocol_version = None
110- self .request_headers = {
111- ** self .headers ,
112- ACCEPT : f"{ JSON } , { SSE } " ,
113- CONTENT_TYPE : JSON ,
114- }
115-
116- def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
117- """Update headers with session ID and protocol version if available."""
118- headers = base_headers .copy ()
86+
87+ def _prepare_headers (self , client : httpx .AsyncClient ) -> dict [str , str ]:
88+ """Build request headers by merging client headers with MCP protocol and session headers."""
89+ headers = dict (client .headers ) if client .headers else {}
90+ # Add MCP protocol headers
91+ headers [ACCEPT ] = f"{ JSON } , { SSE } "
92+ headers [CONTENT_TYPE ] = JSON
93+ # Add session headers if available
11994 if self .session_id :
12095 headers [MCP_SESSION_ID ] = self .session_id
12196 if self .protocol_version :
@@ -206,14 +181,13 @@ async def handle_get_stream(
206181 if not self .session_id :
207182 return
208183
209- headers = self ._prepare_request_headers ( self . request_headers )
184+ headers = self ._prepare_headers ( client )
210185
211186 async with aconnect_sse (
212187 client ,
213188 "GET" ,
214189 self .url ,
215190 headers = headers ,
216- timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
217191 ) as event_source :
218192 event_source .response .raise_for_status ()
219193 logger .debug ("GET SSE connection established" )
@@ -226,7 +200,7 @@ async def handle_get_stream(
226200
227201 async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
228202 """Handle a resumption request using GET with SSE."""
229- headers = self ._prepare_request_headers (ctx .headers )
203+ headers = self ._prepare_headers (ctx .client )
230204 if ctx .metadata and ctx .metadata .resumption_token :
231205 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
232206 else :
@@ -242,7 +216,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
242216 "GET" ,
243217 self .url ,
244218 headers = headers ,
245- timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
246219 ) as event_source :
247220 event_source .response .raise_for_status ()
248221 logger .debug ("Resumption GET SSE connection established" )
@@ -260,7 +233,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
260233
261234 async def _handle_post_request (self , ctx : RequestContext ) -> None :
262235 """Handle a POST request with response processing."""
263- headers = self ._prepare_request_headers (ctx .headers )
236+ headers = self ._prepare_headers (ctx .client )
264237 message = ctx .session_message .message
265238 is_initialization = self ._is_initialization_request (message )
266239
@@ -401,12 +374,10 @@ async def post_writer(
401374
402375 ctx = RequestContext (
403376 client = client ,
404- headers = self .request_headers ,
405377 session_id = self .session_id ,
406378 session_message = session_message ,
407379 metadata = metadata ,
408380 read_stream_writer = read_stream_writer ,
409- sse_read_timeout = self .sse_read_timeout ,
410381 )
411382
412383 async def handle_request_async ():
@@ -433,7 +404,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
433404 return
434405
435406 try :
436- headers = self ._prepare_request_headers ( self . request_headers )
407+ headers = self ._prepare_headers ( client )
437408 response = await client .delete (self .url , headers = headers )
438409
439410 if response .status_code == 405 :
@@ -493,16 +464,8 @@ async def streamable_http_client(
493464 # Create default client with recommended MCP timeouts
494465 client = create_mcp_http_client ()
495466
496- # Extract configuration from the client to pass to transport
497- headers_dict = dict (client .headers ) if client .headers else None
498- timeout = client .timeout .connect if (client .timeout and client .timeout .connect is not None ) else MCP_DEFAULT_TIMEOUT
499- sse_read_timeout = (
500- client .timeout .read if (client .timeout and client .timeout .read is not None ) else MCP_DEFAULT_SSE_READ_TIMEOUT
501- )
502- auth = client .auth
503-
504- # Create transport with extracted configuration
505- transport = StreamableHTTPTransport (url , headers_dict , timeout , sse_read_timeout , auth )
467+ # Create transport
468+ transport = StreamableHTTPTransport (url )
506469
507470 async with anyio .create_task_group () as tg :
508471 try :
0 commit comments