66and session management.
77"""
88
9+ import contextlib
910import logging
1011from collections .abc import AsyncGenerator , Awaitable , Callable
1112from contextlib import asynccontextmanager
1920from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
2021from typing_extensions import deprecated
2122
22- from mcp .shared ._httpx_utils import McpHttpClientFactory , create_mcp_http_client
23+ from mcp .shared ._httpx_utils import (
24+ MCP_DEFAULT_SSE_READ_TIMEOUT ,
25+ MCP_DEFAULT_TIMEOUT ,
26+ McpHttpClientFactory ,
27+ create_mcp_http_client ,
28+ )
2329from mcp .shared .message import ClientMessageMetadata , SessionMessage
2430from mcp .types import (
2531 ErrorData ,
@@ -102,9 +108,9 @@ def __init__(
102108 self .session_id = None
103109 self .protocol_version = None
104110 self .request_headers = {
111+ ** self .headers ,
105112 ACCEPT : f"{ JSON } , { SSE } " ,
106113 CONTENT_TYPE : JSON ,
107- ** self .headers ,
108114 }
109115
110116 def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
@@ -445,12 +451,9 @@ def get_session_id(self) -> str | None:
445451@asynccontextmanager
446452async def streamable_http_client (
447453 url : str ,
448- headers : dict [str , str ] | None = None ,
449- timeout : float | timedelta = 30 ,
450- sse_read_timeout : float | timedelta = 60 * 5 ,
454+ * ,
455+ httpx_client : httpx .AsyncClient | None = None ,
451456 terminate_on_close : bool = True ,
452- httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
453- auth : httpx .Auth | None = None ,
454457) -> AsyncGenerator [
455458 tuple [
456459 MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -462,30 +465,57 @@ async def streamable_http_client(
462465 """
463466 Client transport for StreamableHTTP.
464467
465- `sse_read_timeout` determines how long (in seconds) the client will wait for a new
466- event before disconnecting. All other HTTP operations are controlled by `timeout`.
468+ Args:
469+ url: The MCP server endpoint URL.
470+ httpx_client: Optional pre-configured httpx.AsyncClient. If None, a default
471+ client with recommended MCP timeouts will be created. To configure headers,
472+ authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
473+ terminate_on_close: If True, send a DELETE request to terminate the session
474+ when the context exits.
467475
468476 Yields:
469477 Tuple containing:
470478 - read_stream: Stream for reading messages from the server
471479 - write_stream: Stream for sending messages to the server
472480 - get_session_id_callback: Function to retrieve the current session ID
473- """
474- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout , auth )
475481
482+ Example:
483+ See examples/snippets/clients/ for usage patterns.
484+ """
476485 read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
477486 write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
478487
488+ # Determine if we need to create and manage the client
489+ client_provided = httpx_client is not None
490+ client = httpx_client
491+
492+ if client is None :
493+ # Create default client with recommended MCP timeouts
494+ client = create_mcp_http_client ()
495+
496+ # Extract configuration from the client to pass to transport
497+ headers_dict = dict (client .headers ) if client .headers else None
498+ timeout = client .timeout .connect if (client .timeout and client .timeout .connect is not None ) else MCP_DEFAULT_TIMEOUT
499+ sse_read_timeout = (
500+ client .timeout .read if (client .timeout and client .timeout .read is not None ) else MCP_DEFAULT_SSE_READ_TIMEOUT
501+ )
502+ auth = client .auth
503+
504+ # Create transport with extracted configuration
505+ transport = StreamableHTTPTransport (url , headers_dict , timeout , sse_read_timeout , auth )
506+
507+ # Sync client headers with transport's merged headers (includes MCP protocol requirements)
508+ client .headers .update (transport .request_headers )
509+
479510 async with anyio .create_task_group () as tg :
480511 try :
481512 logger .debug (f"Connecting to StreamableHTTP endpoint: { url } " )
482513
483- async with httpx_client_factory (
484- headers = transport .request_headers ,
485- timeout = httpx .Timeout (transport .timeout , read = transport .sse_read_timeout ),
486- auth = transport .auth ,
487- ) as client :
488- # Define callbacks that need access to tg
514+ async with contextlib .AsyncExitStack () as stack :
515+ # Only manage client lifecycle if we created it
516+ if not client_provided :
517+ await stack .enter_async_context (client )
518+
489519 def start_get_stream () -> None :
490520 tg .start_soon (transport .handle_get_stream , client , read_stream_writer )
491521
@@ -532,7 +562,24 @@ async def streamablehttp_client(
532562 ],
533563 None ,
534564]:
535- async with streamable_http_client (
536- url , headers , timeout , sse_read_timeout , terminate_on_close , httpx_client_factory , auth
537- ) as streams :
538- yield streams
565+ # Convert timeout parameters
566+ timeout_seconds = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
567+ sse_read_timeout_seconds = (
568+ sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
569+ )
570+
571+ # Create httpx client using the factory with old-style parameters
572+ client = httpx_client_factory (
573+ headers = headers ,
574+ timeout = httpx .Timeout (timeout_seconds , read = sse_read_timeout_seconds ),
575+ auth = auth ,
576+ )
577+
578+ # Manage client lifecycle since we created it
579+ async with client :
580+ async with streamable_http_client (
581+ url ,
582+ httpx_client = client ,
583+ terminate_on_close = terminate_on_close ,
584+ ) as streams :
585+ yield streams
0 commit comments