Skip to content

Commit 7f25ae9

Browse files
authored
Add ability for client to resume session (#1029)
1 parent 979befd commit 7f25ae9

File tree

9 files changed

+433
-3
lines changed

9 files changed

+433
-3
lines changed

src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can
9393

9494
private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
9595
{
96+
if (_options.KnownSessionId is not null)
97+
{
98+
throw new InvalidOperationException("Streamable HTTP transport is required to resume an existing session.");
99+
}
100+
96101
var sseTransport = new SseClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory);
97102

98103
try

src/ModelContextProtocol.Core/Client/HttpClientTransport.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ public HttpClientTransport(HttpClientTransportOptions transportOptions, HttpClie
7272
/// <inheritdoc />
7373
public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
7474
{
75+
if (_options.KnownSessionId is not null && _options.TransportMode == HttpTransportMode.Sse)
76+
{
77+
throw new InvalidOperationException("SSE transport does not support resuming an existing session.");
78+
}
79+
7580
return _options.TransportMode switch
7681
{
7782
HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory),

src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,33 @@ public required Uri Endpoint
7373
/// </remarks>
7474
public IDictionary<string, string>? AdditionalHeaders { get; set; }
7575

76+
/// <summary>
77+
/// Gets or sets a session identifier that should be reused when connecting to a Streamable HTTP server.
78+
/// </summary>
79+
/// <remarks>
80+
/// <para>
81+
/// When non-<see langword="null"/>, the transport assumes the server already created the session and will include the
82+
/// specified session identifier in every HTTP request. This allows reconnecting to an existing session created in a
83+
/// previous process. This option is only supported by the Streamable HTTP transport mode.
84+
/// </para>
85+
/// <para>
86+
/// Clients should pair this with
87+
/// <see cref="McpClient.ResumeSessionAsync(IClientTransport, ResumeClientSessionOptions, McpClientOptions?, Microsoft.Extensions.Logging.ILoggerFactory?, CancellationToken)"/>
88+
/// to skip the initialization handshake when rehydrating a previously negotiated session.
89+
/// </para>
90+
/// </remarks>
91+
public string? KnownSessionId { get; set; }
92+
93+
/// <summary>
94+
/// Gets or sets a value indicating whether this transport endpoint is responsible for ending the session on dispose.
95+
/// </summary>
96+
/// <remarks>
97+
/// When <see langword="true"/> (default), the transport sends a DELETE request that informs the server the session is
98+
/// complete. Set this to <see langword="false"/> when creating a transport used solely to bootstrap session information
99+
/// that will later be resumed elsewhere.
100+
/// </remarks>
101+
public bool OwnsSession { get; set; } = true;
102+
76103
/// <summary>
77104
/// Gets sor sets the authorization provider to use for authentication.
78105
/// </summary>

src/ModelContextProtocol.Core/Client/McpClient.Methods.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ public static async Task<McpClient> CreateAsync(
5050
return clientSession;
5151
}
5252

53+
/// <summary>
54+
/// Recreates an <see cref="McpClient"/> using an existing transport session without sending a new initialize request.
55+
/// </summary>
56+
/// <param name="clientTransport">The transport instance already configured to connect to the target server.</param>
57+
/// <param name="resumeOptions">The metadata captured from the original session that should be applied when resuming.</param>
58+
/// <param name="clientOptions">Optional client settings that should mirror those used to create the original session.</param>
59+
/// <param name="loggerFactory">An optional logger factory for diagnostics.</param>
60+
/// <param name="cancellationToken">Token used when establishing the transport connection.</param>
61+
/// <returns>An <see cref="McpClient"/> bound to the resumed session.</returns>
62+
/// <exception cref="ArgumentNullException">Thrown when <paramref name="clientTransport"/> or <paramref name="resumeOptions"/> is <see langword="null"/>.</exception>
63+
public static async Task<McpClient> ResumeSessionAsync(
64+
IClientTransport clientTransport,
65+
ResumeClientSessionOptions resumeOptions,
66+
McpClientOptions? clientOptions = null,
67+
ILoggerFactory? loggerFactory = null,
68+
CancellationToken cancellationToken = default)
69+
{
70+
Throw.IfNull(clientTransport);
71+
Throw.IfNull(resumeOptions);
72+
Throw.IfNull(resumeOptions.ServerCapabilities);
73+
Throw.IfNull(resumeOptions.ServerInfo);
74+
75+
var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
76+
var endpointName = clientTransport.Name;
77+
78+
var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory);
79+
clientSession.ResumeSession(resumeOptions);
80+
return clientSession;
81+
}
82+
5383
/// <summary>
5484
/// Sends a ping request to verify server connectivity.
5585
/// </summary>

src/ModelContextProtocol.Core/Client/McpClientImpl.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not
8080
cancellationToken),
8181
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
8282
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
83-
83+
8484
_options.Capabilities ??= new();
8585
_options.Capabilities.Sampling ??= new();
8686
}
@@ -207,6 +207,28 @@ await this.SendNotificationAsync(
207207
LogClientConnected(_endpointName);
208208
}
209209

210+
/// <summary>
211+
/// Configures the client to use an already initialized session without performing the handshake.
212+
/// </summary>
213+
/// <param name="resumeOptions">The metadata captured from the previous session that should be applied to the resumed client.</param>
214+
internal void ResumeSession(ResumeClientSessionOptions resumeOptions)
215+
{
216+
Throw.IfNull(resumeOptions);
217+
Throw.IfNull(resumeOptions.ServerCapabilities);
218+
Throw.IfNull(resumeOptions.ServerInfo);
219+
220+
_ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None);
221+
222+
_serverCapabilities = resumeOptions.ServerCapabilities;
223+
_serverInfo = resumeOptions.ServerInfo;
224+
_serverInstructions = resumeOptions.ServerInstructions;
225+
_negotiatedProtocolVersion = resumeOptions.NegotiatedProtocolVersion
226+
?? _options.ProtocolVersion
227+
?? McpSessionHandler.LatestProtocolVersion;
228+
229+
LogClientSessionResumed(_endpointName);
230+
}
231+
210232
/// <inheritdoc/>
211233
public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
212234
=> _sessionHandler.SendRequestAsync(request, cancellationToken);
@@ -249,4 +271,7 @@ public override async ValueTask DisposeAsync()
249271

250272
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")]
251273
private partial void LogClientConnected(string endpointName);
274+
275+
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client resumed existing session.")]
276+
private partial void LogClientSessionResumed(string endpointName);
252277
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using ModelContextProtocol.Protocol;
2+
3+
namespace ModelContextProtocol.Client;
4+
5+
/// <summary>
6+
/// Provides the metadata captured from a previous MCP client session that is required to resume it.
7+
/// </summary>
8+
public sealed class ResumeClientSessionOptions
9+
{
10+
/// <summary>
11+
/// Gets or sets the server capabilities that were negotiated during the original session setialization.
12+
/// </summary>
13+
public required ServerCapabilities ServerCapabilities { get; set; }
14+
15+
/// <summary>
16+
/// Gets or sets the server implementation metadata that identifies the connected MCP server.
17+
/// </summary>
18+
public required Implementation ServerInfo { get; set; }
19+
20+
/// <summary>
21+
/// Gets or sets any instructions previously supplied by the server.
22+
/// </summary>
23+
public string? ServerInstructions { get; set; }
24+
25+
/// <summary>
26+
/// Gets or sets the protocol version that was negotiated with the server.
27+
/// </summary>
28+
public string? NegotiatedProtocolVersion { get; set; }
29+
}

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ public StreamableHttpClientSessionTransport(
4747
// until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync
4848
// so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user.
4949
SetConnected();
50+
51+
if (_options.KnownSessionId is { } knownSessionId)
52+
{
53+
SessionId = knownSessionId;
54+
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
55+
}
5056
}
5157

5258
/// <inheritdoc/>
@@ -60,6 +66,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
6066
// This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception.
6167
internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken)
6268
{
69+
if (_options.KnownSessionId is not null &&
70+
message is JsonRpcRequest { Method: RequestMethods.Initialize })
71+
{
72+
throw new InvalidOperationException(
73+
$"Cannot send '{RequestMethods.Initialize}' when {nameof(HttpClientTransportOptions)}.{nameof(HttpClientTransportOptions.KnownSessionId)} is configured. " +
74+
$"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
75+
}
76+
6377
using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
6478
cancellationToken = sendCts.Token;
6579

@@ -116,7 +130,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
116130
var initializeResult = JsonSerializer.Deserialize(initResponse.Result, McpJsonUtilities.JsonContext.Default.InitializeResult);
117131
_negotiatedProtocolVersion = initializeResult?.ProtocolVersion;
118132

119-
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
133+
_getReceiveTask ??= ReceiveUnsolicitedMessagesAsync();
120134
}
121135

122136
return response;
@@ -139,7 +153,7 @@ public override async ValueTask DisposeAsync()
139153
try
140154
{
141155
// Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec.
142-
if (!string.IsNullOrEmpty(SessionId))
156+
if (_options.OwnsSession && !string.IsNullOrEmpty(SessionId))
143157
{
144158
await SendDeleteRequest();
145159
}

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
using Microsoft.Extensions.DependencyInjection;
33
using Microsoft.Extensions.Primitives;
44
using ModelContextProtocol.Client;
5+
using ModelContextProtocol.Protocol;
6+
using ModelContextProtocol.Server;
57
using System.Collections.Concurrent;
8+
using System.Threading;
9+
using System.Threading.Tasks;
610

711
namespace ModelContextProtocol.AspNetCore.Tests;
812

@@ -188,4 +192,95 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia
188192
Assert.True(protocolVersionHeaderValues.Count > 1);
189193
Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v));
190194
}
195+
196+
[Fact]
197+
public async Task CanResumeSessionWithMapMcpAndRunSessionHandler()
198+
{
199+
Assert.SkipWhen(Stateless, "Session resumption relies on server-side session tracking.");
200+
201+
var runSessionCount = 0;
202+
var serverTcs = new TaskCompletionSource<McpServer>(TaskCreationOptions.RunContinuationsAsynchronously);
203+
204+
Builder.Services.AddMcpServer(options =>
205+
{
206+
options.ServerInfo = new Implementation
207+
{
208+
Name = "ResumeServer",
209+
Version = "1.0.0",
210+
};
211+
}).WithHttpTransport(opts =>
212+
{
213+
ConfigureStateless(opts);
214+
opts.RunSessionHandler = async (context, server, cancellationToken) =>
215+
{
216+
Interlocked.Increment(ref runSessionCount);
217+
serverTcs.TrySetResult(server);
218+
await server.RunAsync(cancellationToken);
219+
};
220+
}).WithTools<EchoHttpContextUserTools>();
221+
222+
await using var app = Builder.Build();
223+
app.MapMcp();
224+
await app.StartAsync(TestContext.Current.CancellationToken);
225+
226+
ServerCapabilities? serverCapabilities = null;
227+
Implementation? serverInfo = null;
228+
string? serverInstructions = null;
229+
string? negotiatedProtocolVersion = null;
230+
string? resumedSessionId = null;
231+
232+
await using var initialTransport = new HttpClientTransport(new()
233+
{
234+
Endpoint = new("http://localhost:5000/"),
235+
TransportMode = HttpTransportMode.StreamableHttp,
236+
OwnsSession = false,
237+
}, HttpClient, LoggerFactory);
238+
239+
await using (var initialClient = await McpClient.CreateAsync(initialTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken))
240+
{
241+
resumedSessionId = initialClient.SessionId ?? throw new InvalidOperationException("SessionId not negotiated.");
242+
serverCapabilities = initialClient.ServerCapabilities;
243+
serverInfo = initialClient.ServerInfo;
244+
serverInstructions = initialClient.ServerInstructions;
245+
negotiatedProtocolVersion = initialClient.NegotiatedProtocolVersion;
246+
247+
await initialClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
248+
}
249+
250+
Assert.NotNull(serverCapabilities);
251+
Assert.NotNull(serverInfo);
252+
Assert.False(string.IsNullOrEmpty(resumedSessionId));
253+
254+
await serverTcs.Task.WaitAsync(TimeSpan.FromSeconds(5), TestContext.Current.CancellationToken);
255+
256+
await using var resumeTransport = new HttpClientTransport(new()
257+
{
258+
Endpoint = new("http://localhost:5000/"),
259+
TransportMode = HttpTransportMode.StreamableHttp,
260+
KnownSessionId = resumedSessionId!,
261+
}, HttpClient, LoggerFactory);
262+
263+
var resumeOptions = new ResumeClientSessionOptions
264+
{
265+
ServerCapabilities = serverCapabilities!,
266+
ServerInfo = serverInfo!,
267+
ServerInstructions = serverInstructions,
268+
NegotiatedProtocolVersion = negotiatedProtocolVersion,
269+
};
270+
271+
await using (var resumedClient = await McpClient.ResumeSessionAsync(
272+
resumeTransport,
273+
resumeOptions,
274+
loggerFactory: LoggerFactory,
275+
cancellationToken: TestContext.Current.CancellationToken))
276+
{
277+
var tools = await resumedClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
278+
Assert.NotEmpty(tools);
279+
280+
Assert.Equal(serverInstructions, resumedClient.ServerInstructions);
281+
Assert.Equal(negotiatedProtocolVersion, resumedClient.NegotiatedProtocolVersion);
282+
}
283+
284+
Assert.Equal(1, runSessionCount);
285+
}
191286
}

0 commit comments

Comments
 (0)