diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 2d1f4b43c..055b9b5f0 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -41,6 +41,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.context.ContextView; /** * The Model Context Protocol (MCP) client implementation that provides asynchronous @@ -314,13 +315,21 @@ public class McpAsyncClient { }; this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), - initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx)), + initializationTimeout, ctx -> createSession(ctx, requestTimeout, requestHandlers, notificationHandlers), postInitializationHook); this.transport.setExceptionHandler(this.initializer::handleException); } + /** + * An extension point to create a custom McpClientSession with additional context. + */ + protected McpClientSession createSession(ContextView ctx, Duration requestTimeout, + Map> requestHandlers, Map notificationHandlers) { + return new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, + con -> con.contextWrite(ctx)); + } + /** * Get the current initialization result. * @return the initialization result. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0ba7ab3b8..3ee8f8db2 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -145,7 +145,11 @@ private void dismissPendingResponses() { this.pendingResponses.clear(); } - private void handle(McpSchema.JSONRPCMessage message) { + /** + * An extension point for handling incoming JSON-RPC messages. + * @param message The incoming JSON-RPC message + */ + protected void handle(McpSchema.JSONRPCMessage message) { if (message instanceof McpSchema.JSONRPCResponse response) { logger.debug("Received response: {}", response); if (response.id() != null) { @@ -198,7 +202,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { * @param request The incoming JSON-RPC request * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + protected Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { return Mono.defer(() -> { var handler = this.requestHandlers.get(request.method()); if (handler == null) { @@ -231,7 +235,7 @@ private MethodNotFoundError getMethodNotFoundError(String method) { * @param notification The incoming JSON-RPC notification * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + protected Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { return Mono.defer(() -> { var handler = notificationHandlers.get(notification.method()); if (handler == null) {