Skip to content
Open
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ classifiers = [
dependencies = [
"anyio>=4.5",
"httpx>=0.27",
"httpx-sse>=0.4",
"httpx-sse>=0.4.1",
"pydantic>=2.7.2,<3.0.0",
"starlette>=0.27",
"python-multipart>=0.0.9",
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"typing_extensions>=4.12",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"jsonschema>=4.20.0",
]
Expand All @@ -49,10 +50,10 @@ required-version = ">=0.7.2"

[dependency-groups]
dev = [
"anyio[trio]",
"pyright>=1.1.391",
"pytest>=8.3.4",
"ruff>=0.8.5",
"trio>=0.26.2",
"pytest-flakefinder>=1.1.0",
"pytest-xdist>=3.6.1",
"pytest-examples>=0.0.14",
Expand Down Expand Up @@ -123,5 +124,5 @@ filterwarnings = [
# This should be fixed on Uvicorn's side.
"ignore::DeprecationWarning:websockets",
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
]
66 changes: 33 additions & 33 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import aclosing, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta

Expand Down Expand Up @@ -240,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
async with aclosing(event_source.aiter_sse()) as iterator:
async for sse in iterator:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
Expand Down Expand Up @@ -319,18 +320,18 @@ async def _handle_sse_response(
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion (e.g. returning a response/error),
# break the loop
if is_complete:
break
async with aclosing(EventSource(response).aiter_sse()) as items:
async for sse in items:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion (e.g. returning a response/error),
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
Expand Down Expand Up @@ -471,15 +472,14 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
async with create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
async with anyio.create_task_group() as tg:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
Expand All @@ -504,6 +504,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import Self

import mcp.types as types
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -93,10 +94,14 @@ def __init__(
)

self._init_options = init_options

async def __aenter__(self) -> Self:
await super().__aenter__()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose)
return self

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand Down
23 changes: 13 additions & 10 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self
Expand Down Expand Up @@ -177,6 +178,8 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_exit_stack: AsyncExitStack
_task_group: TaskGroup

def __init__(
self,
Expand All @@ -196,12 +199,17 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(anyio.create_task_group())
self._task_group.start_soon(self._receive_loop)
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
exit_stack.callback(self._task_group.cancel_scope.cancel)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(
Expand All @@ -210,12 +218,7 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)

async def send_request(
self,
Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,15 @@ async def mock_server():
)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)

Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py

This file was deleted.

Loading
Loading