Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
<PackageVersion Include="Microsoft.Extensions.Logging" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.Logging.Console" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.Options" Version="9.0.4" />
<PackageVersion Include="Microsoft.Extensions.TimeProvider.Testing" Version="9.4.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageVersion Include="Moq" Version="4.20.72" />
<PackageVersion Include="OpenTelemetry" Version="1.11.2" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public static class HttpMcpServerBuilderExtensions
public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action<HttpServerTransportOptions>? configureOptions = null)
{
ArgumentNullException.ThrowIfNull(builder);

builder.Services.TryAddSingleton<StreamableHttpHandler>();
builder.Services.TryAddSingleton<SseHandler>();
builder.Services.AddHostedService<IdleTrackingBackgroundService>();
Expand Down
15 changes: 10 additions & 5 deletions src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ public async ValueTask DisposeAsync()
}
finally
{
if (Server is not null)
try
{
await Server.DisposeAsync();
if (Server is not null)
{
await Server.DisposeAsync();
}
}
finally
{
await Transport.DisposeAsync();
_disposeCts.Dispose();
}

await Transport.DisposeAsync();
_disposeCts.Dispose();
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,17 @@ public class HttpServerTransportOptions
/// Represents the duration of time the server will wait between any active requests before timing out an
/// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will
/// receive a 404 status code and should restart their session. A client can keep their session open by
/// keeping a GET request open. The default value is set to 2 minutes.
/// keeping a GET request open. The default value is set to 2 hours.
/// </summary>
public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2);
public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromHours(2);

/// <summary>
/// The maximum number of idle sessions to track. This is used to limit the number of sessions that can be idle at once.
/// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached
/// their <see cref="IdleTimeout"/> until the idle session count is below this limit. Clients that keep their session open by
/// keeping a GET request open will not count towards this limit. The default value is set to 10,000 sessions.
/// </summary>
public int MaxIdleSessionCount { get; set; } = 10_000;

/// <summary>
/// Used for testing the <see cref="IdleTimeout"/>.
Expand Down
107 changes: 81 additions & 26 deletions src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,39 @@ namespace ModelContextProtocol.AspNetCore;
internal sealed partial class IdleTrackingBackgroundService(
StreamableHttpHandler handler,
IOptions<HttpServerTransportOptions> options,
IHostApplicationLifetime appLifetime,
ILogger<IdleTrackingBackgroundService> logger) : BackgroundService
{
// The compiler will complain about the parameter being unused otherwise despite the source generator.
private ILogger _logger = logger;

// We can make this configurable once we properly harden the MCP server. In the meantime, anyone running
// this should be taking a cattle not pets approach to their servers and be able to launch more processes
// to handle more than 10,000 idle sessions at a time.
private const int MaxIdleSessionCount = 10_000;

protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
var timeProvider = options.Value.TimeProvider;
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);
// Still run loop given infinite IdleTimeout to enforce the MaxIdleSessionCount and assist graceful shutdown.
if (options.Value.IdleTimeout != Timeout.InfiniteTimeSpan)
{
ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.IdleTimeout, TimeSpan.Zero);
}
ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.MaxIdleSessionCount, 0);

try
{
var timeProvider = options.Value.TimeProvider;
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);

var idleTimeoutTicks = options.Value.IdleTimeout.Ticks;
var maxIdleSessionCount = options.Value.MaxIdleSessionCount;

var idleSessions = new SortedSet<(string SessionId, long Timestamp)>(SessionTimestampComparer.Instance);

while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken))
{
var idleActivityCutoff = timeProvider.GetTimestamp() - options.Value.IdleTimeout.Ticks;
var idleActivityCutoff = idleTimeoutTicks switch
{
< 0 => long.MinValue,
var ticks => timeProvider.GetTimestamp() - ticks,
};

var idleCount = 0;
foreach (var (_, session) in handler.Sessions)
{
if (session.IsActive || session.SessionClosed.IsCancellationRequested)
Expand All @@ -38,34 +49,40 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
continue;
}

idleCount++;
if (idleCount == MaxIdleSessionCount)
{
// Emit critical log at most once every 5 seconds the idle count it exceeded,
//since the IdleTimeout will no longer be respected.
LogMaxSessionIdleCountExceeded();
}
else if (idleCount < MaxIdleSessionCount && session.LastActivityTicks > idleActivityCutoff)
if (session.LastActivityTicks < idleActivityCutoff)
{
RemoveAndCloseSession(session.Id);
continue;
}

if (handler.Sessions.TryRemove(session.Id, out var removedSession))
idleSessions.Add((session.Id, session.LastActivityTicks));

// Emit critical log at most once every 5 seconds the idle count it exceeded,
// since the IdleTimeout will no longer be respected.
if (idleSessions.Count == maxIdleSessionCount + 1)
{
LogSessionIdle(removedSession.Id);
LogMaxSessionIdleCountExceeded(maxIdleSessionCount);
}
}

// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
_ = DisposeSessionAsync(removedSession);
if (idleSessions.Count > maxIdleSessionCount)
{
var sessionsToPrune = idleSessions.ToArray()[..^maxIdleSessionCount];
foreach (var (id, _) in sessionsToPrune)
{
RemoveAndCloseSession(id);
}
}

idleSessions.Clear();
}
}
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
{
}
finally
{
if (stoppingToken.IsCancellationRequested)
try
{
List<Task> disposeSessionTasks = [];

Expand All @@ -79,7 +96,29 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)

await Task.WhenAll(disposeSessionTasks);
}
finally
{
if (!stoppingToken.IsCancellationRequested)
{
// Something went terribly wrong. A very unexpected exception must be bubbling up, but let's ensure we also stop the application,
// so that it hopefully gets looked at and restarted. This shouldn't really be reachable.
appLifetime.StopApplication();
IdleTrackingBackgroundServiceStoppedUnexpectedly();
}
}
}
}

private void RemoveAndCloseSession(string sessionId)
{
if (!handler.Sessions.TryRemove(sessionId, out var session))
{
return;
}

LogSessionIdle(session.Id);
// Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
_ = DisposeSessionAsync(session);
}

private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransport> session)
Expand All @@ -94,12 +133,28 @@ private async Task DisposeSessionAsync(HttpMcpSession<StreamableHttpServerTransp
}
}

private sealed class SessionTimestampComparer : IComparer<(string SessionId, long Timestamp)>
{
public static SessionTimestampComparer Instance { get; } = new();

public int Compare((string SessionId, long Timestamp) x, (string SessionId, long Timestamp) y) =>
x.Timestamp.CompareTo(y.Timestamp) switch
{
// Use a SessionId comparison as tiebreaker to ensure uniqueness in the SortedSet.
0 => string.CompareOrdinal(x.SessionId, y.SessionId),
var timestampComparison => timestampComparison,
};
}

[LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")]
private partial void LogSessionIdle(string sessionId);

[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded static maximum of 10,000 idle connections. Now clearing all inactive connections regardless of timeout.")]
private partial void LogMaxSessionIdleCountExceeded();

[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing the IMcpServer for session {sessionId}.")]
[LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")]
private partial void LogSessionDisposeError(string sessionId, Exception ex);

[LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")]
private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount);

[LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")]
private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly();
}
40 changes: 19 additions & 21 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Net.Http.Headers;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
Expand All @@ -23,18 +24,19 @@ internal sealed class StreamableHttpHandler(
IServiceProvider applicationServices)
{
private static JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();
private static MediaTypeHeaderValue ApplicationJsonMediaType = new("application/json");
private static MediaTypeHeaderValue TextEventStreamMediaType = new("text/event-stream");

public ConcurrentDictionary<string, HttpMcpSession<StreamableHttpServerTransport>> Sessions { get; } = new(StringComparer.Ordinal);

public async Task HandlePostRequestAsync(HttpContext context)
{
// The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream.
// ASP.NET Core Minimal APIs mostly ry to stay out of the business of response content negotiation, so
// we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, but it's
// probably good to at least start out trying to be strict.
var acceptHeader = context.Request.Headers.Accept.ToString();
if (!acceptHeader.Contains("application/json", StringComparison.Ordinal) ||
!acceptHeader.Contains("text/event-stream", StringComparison.Ordinal))
// ASP.NET Core Minimal APIs mostly try to stay out of the business of response content negotiation,
// so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests,
// but it's probably good to at least start out trying to be strict.
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(ApplicationJsonMediaType) || !acceptHeaders.Contains(TextEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept both application/json and text/event-stream",
Expand All @@ -49,9 +51,8 @@ await WriteJsonRpcErrorAsync(context,
}

using var _ = session.AcquireReference();
using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, session.SessionClosed);
InitializeSseResponse(context);
var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), cts.Token);
var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted);
if (!wroteResponse)
{
// We wound up writing nothing, so there should be no Content-Type response header.
Expand All @@ -62,8 +63,8 @@ await WriteJsonRpcErrorAsync(context,

public async Task HandleGetRequestAsync(HttpContext context)
{
var acceptHeader = context.Request.Headers.Accept.ToString();
if (!acceptHeader.Contains("application/json", StringComparison.Ordinal))
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(TextEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept text/event-stream",
Expand Down Expand Up @@ -105,12 +106,6 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
}
}

private void InitializeSessionResponse(HttpContext context, HttpMcpSession<StreamableHttpServerTransport> session)
{
context.Response.Headers["mcp-session-id"] = session.Id;
context.Features.Set(session.Server);
}

private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>?> GetSessionAsync(HttpContext context, string sessionId)
{
if (Sessions.TryGetValue(sessionId, out var existingSession))
Expand All @@ -123,7 +118,8 @@ await WriteJsonRpcErrorAsync(context,
return null;
}

InitializeSessionResponse(context, existingSession);
context.Response.Headers["mcp-session-id"] = existingSession.Id;
context.Features.Set(existingSession.Server);
return existingSession;
}

Expand All @@ -138,11 +134,10 @@ await WriteJsonRpcErrorAsync(context,
private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>?> GetOrCreateSessionAsync(HttpContext context)
{
var sessionId = context.Request.Headers["mcp-session-id"].ToString();
HttpMcpSession<StreamableHttpServerTransport>? session;

if (string.IsNullOrEmpty(sessionId))
{
session = await CreateSessionAsync(context);
var session = await CreateSessionAsync(context);

if (!Sessions.TryAdd(session.Id, session))
{
Expand All @@ -159,6 +154,9 @@ await WriteJsonRpcErrorAsync(context,

private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> CreateSessionAsync(HttpContext context)
{
var sessionId = MakeNewSessionId();
context.Response.Headers["mcp-session-id"] = sessionId;

var mcpServerOptions = mcpServerOptionsSnapshot.Value;
if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
{
Expand All @@ -169,16 +167,16 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> CreateSes
var transport = new StreamableHttpServerTransport();
// Use application instead of request services, because the session will likely outlive the first initialization request.
var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices);
context.Features.Set(server);

var session = new HttpMcpSession<StreamableHttpServerTransport>(MakeNewSessionId(), transport, context.User, httpMcpServerOptions.Value.TimeProvider)
var session = new HttpMcpSession<StreamableHttpServerTransport>(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider)
{
Server = server,
};

var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync;
session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed);

InitializeSessionResponse(context, session);
return session;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Protocol/Transport/SseWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
{
Throw.IfNull(message);

using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false);

if (_disposed)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
}

using var getCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken);
await _sseWriter.WriteAllAsync(sseResponseStream, getCts.Token).ConfigureAwait(false);
// We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully.
await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
32 changes: 32 additions & 0 deletions tests/Common/Utils/MockLoggerProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.Extensions.Logging;
using System.Collections.Concurrent;

namespace ModelContextProtocol.Tests.Utils;

public class MockLoggerProvider() : ILoggerProvider
{
public ConcurrentQueue<(string Category, LogLevel LogLevel, string Message, Exception? Exception)> LogMessages { get; } = [];

public ILogger CreateLogger(string categoryName)
{
return new MockLogger(this, categoryName);
}

public void Dispose()
{
}

private class MockLogger(MockLoggerProvider mockProvider, string category) : ILogger
{
public void Log<TState>(
LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
mockProvider.LogMessages.Enqueue((category, logLevel, formatter(state, exception), exception));
}

public bool IsEnabled(LogLevel logLevel) => true;

// The MockLoggerProvider is a convenient NoopDisposable
public IDisposable BeginScope<TState>(TState state) where TState : notnull => mockProvider;
}
}
Loading
Loading