diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe95..c519b5580 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -127,6 +127,28 @@ test + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + test + + diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index e60451706..3f7104fc5 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -29,6 +29,7 @@ import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.util.Assert; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -117,10 +118,6 @@ public static Builder builder(WebClient.Builder webClientBuilder) { public Mono connect(Function, Mono> handler) { return Mono.deferContextual(ctx -> { this.handler.set(handler); - if (openConnectionOnStartup) { - logger.debug("Eagerly opening connection on startup"); - return this.reconnect(null).then(); - } return Mono.empty(); }); } @@ -250,11 +247,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { }) .bodyValue(message) .exchangeToFlux(response -> { - if (transportSession - .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { - // Once we have a session, we try to open an async stream for - // the server to send notifications and requests out-of-band. - reconnect(null).contextWrite(sink.contextView()).subscribe(); + transportSession.markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id")); + if (response.statusCode().is2xxSuccessful() + && message instanceof JSONRPCNotification notification) { + if (notification.method().equals("notifications/initialized")) { + // Establish SSE stream after session is initialized + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } } String sessionRepresentation = sessionIdOrPlaceholder(transportSession); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 42b91d14e..1cc7b2feb 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -12,6 +12,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -161,7 +163,7 @@ void testBuilderPattern() { @Test void testMessageProcessing() { // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Simulate receiving the message @@ -192,7 +194,7 @@ void testResponseMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -216,7 +218,7 @@ void testErrorMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -246,7 +248,7 @@ void testGracefulShutdown() { StepVerifier.create(transport.closeGracefully()).verifyComplete(); // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message is not processed after shutdown @@ -292,10 +294,10 @@ void testMultipleMessageProcessing() { """); // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"), Map.of("key", "value1")); - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"), Map.of("key", "value2")); // Verify both messages are processed diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java new file mode 100644 index 000000000..51e06358d --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java @@ -0,0 +1,651 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.Map; +import java.util.List; +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpError; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +/** + * Integration tests for @link{StreamableHttpServerTransportProvider} with + * + * @link{WebClientStreamableHttpTransport}. + */ +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class StreamableHttpTransportIntegrationTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String ENDPOINT = "/mcp"; + + private StreamableHttpServerTransportProvider serverTransportProvider; + + private McpClient.AsyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + void setUp() { + serverTransportProvider = new StreamableHttpServerTransportProvider(new ObjectMapper(), ENDPOINT, null); + + // Set up session factory with proper server capabilities + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + null); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2024-11-05", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of(), java.util.Map.of())); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, serverTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + WebClientStreamableHttpTransport clientTransport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .build(); + + clientBuilder = McpClient.async(clientTransport) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")); + } + + @AfterEach + void tearDown() { + if (serverTransportProvider != null) { + serverTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + private void setupServerSession(McpServerFeatures.AsyncToolSpecification tool, Duration timeout, + Map> additionalHandlers) { + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + new McpSchema.ServerCapabilities.ToolCapabilities(true)); + Map> handlers = new java.util.HashMap<>(); + handlers.put("tools/call", + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> tool.call().apply(exchange, (Map) params)); + handlers.putAll(additionalHandlers); + serverTransportProvider.setStreamableHttpSessionFactory( + sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, timeout, + request -> Mono.just(new McpSchema.InitializeResult("2024-11-05", serverCapabilities, + new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> Mono.empty(), handlers, Map.of())); + } + + private void setupServerSession(McpServerFeatures.AsyncToolSpecification tool) { + setupServerSession(tool, Duration.ofSeconds(30), Map.of()); + } + + private io.modelcontextprotocol.client.McpAsyncClient buildClientWithCapabilities( + McpSchema.ClientCapabilities capabilities, + Function> samplingHandler, + Function> elicitationHandler) { + var builder = McpClient + .async(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .build()) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")); + if (capabilities != null) { + builder = builder.capabilities(capabilities); + } + if (samplingHandler != null) { + builder = builder.sampling(samplingHandler); + } + if (elicitationHandler != null) { + builder = builder.elicitation(elicitationHandler); + } + return builder.build(); + } + + private void executeTestWithClient(io.modelcontextprotocol.client.McpAsyncClient mcpClient, Runnable testLogic) { + try { + mcpClient.initialize().block(); + testLogic.run(); + } + finally { + mcpClient.close(); + } + } + + @Test + @Order(1) + void shouldInitializeSuccessfully() { + // The server is already configured via the session factory in setUp + var mcpClient = clientBuilder.build(); + try { + InitializeResult result = mcpClient.initialize().block(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("Test Server"); + } + finally { + mcpClient.close(); + } + } + + @Test + @Order(2) + void shouldCallToolSuccessfully() { + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool description", "{}"), + (exchange, request) -> Mono.just(callResponse)); + + setupServerSession(tool); + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()) + .isEqualTo("Tool executed successfully"); + }); + } + + @Test + @Order(3) + void shouldCallToolWithUpgradedTransportSuccessfully() { + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool description", "{}"), (exchange, request) -> { + exchange.upgradeTransport(); + return Mono.just(callResponse); + }); + + setupServerSession(tool); + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()) + .isEqualTo("Tool executed successfully"); + }); + } + + @Test + @Order(4) + void shouldReceiveNotificationThroughGetStream() throws InterruptedException { + CountDownLatch notificationLatch = new CountDownLatch(1); + AtomicReference receivedNotification = new AtomicReference<>(); + + // Build client with logging notification handler + var mcpClient = McpClient + .async(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .openConnectionOnStartup(true) + .build()) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")) + .loggingConsumers(List.of(notification -> { + if ("test message".equals(notification.data())) { + receivedNotification.set(notification.data()); + notificationLatch.countDown(); + } + return Mono.empty(); + })) + .build(); + + try { + mcpClient.initialize().block(); + + // Wait for post-initialize GET/listening stream establishment + Thread.sleep(500); + + // Send logging notification from server + serverTransportProvider + .notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + new McpSchema.LoggingMessageNotification(McpSchema.LoggingLevel.INFO, "server", "test message")) + .block(); + + // Wait for notification to be received + assertThat(notificationLatch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(receivedNotification.get()).isEqualTo("test message"); + } + finally { + mcpClient.close(); + } + } + + @Test + @Order(5) + void shouldCreateMessageSuccessfully() { + var createMessageResult = new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, + new McpSchema.TextContent("Test response"), "test-model", + McpSchema.CreateMessageResult.StopReason.END_TURN); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("create-message-tool", "Tool that creates messages", "{}"), (exchange, request) -> { + var createRequest = new McpSchema.CreateMessageRequest( + List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test prompt"))), + null, null, null, null, 0, null, null, null); + return exchange.createMessage(createRequest) + .then(Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Message created")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(30), Map.of(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> Mono.just(createMessageResult))); + + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().sampling().build(), + request -> Mono.just(new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, + new McpSchema.TextContent("Sampled response"), "test-model", + McpSchema.CreateMessageResult.StopReason.END_TURN)), + null); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("create-message-tool", Map.of())).block(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Message created"); + }); + } + + @Test + @Order(6) + void shouldCreateElicitationSuccessfully() { + var elicitResult = new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("response", "user response")); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("elicit-tool", "Tool that creates elicitations", "{}"), (exchange, request) -> { + var elicitRequest = new McpSchema.ElicitRequest("Please provide input", null, null); + return exchange.createElicitation(elicitRequest) + .then(Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Elicitation created")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(30), Map.of(McpSchema.METHOD_ELICITATION_CREATE, + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> Mono.just(elicitResult))); + + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().elicitation().build(), null, + request -> Mono.just(new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("response", "elicited response")))); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("elicit-tool", Map.of())).block(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Elicitation created"); + }); + } + + @Test + @Order(7) + void shouldListRootsSuccessfully() { + var listRootsResult = new McpSchema.ListRootsResult(List.of(new McpSchema.Root("file:///test", "Test root")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("list-roots-tool", "Tool that lists roots", "{}"), (exchange, request) -> { + return exchange.listRoots() + .then(Mono.just(new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Roots listed")), + null))); + }); + + setupServerSession(tool, Duration.ofSeconds(30), Map.of(McpSchema.METHOD_ROOTS_LIST, + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> Mono.just(listRootsResult))); + + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().roots(true).build(), null, + null); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("list-roots-tool", Map.of())).block(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Roots listed"); + }); + } + + @Test + @Order(8) + void shouldSendLoggingNotificationSuccessfully() { + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("logging-tool", "Tool that sends logging notifications", "{}"), + (exchange, request) -> { + var notification = new McpSchema.LoggingMessageNotification(McpSchema.LoggingLevel.INFO, + "test-logger", "Test log message"); + return exchange.loggingNotification(notification) + .then(Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Notification sent")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(30), + Map.of(McpSchema.METHOD_NOTIFICATION_MESSAGE, + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, params) -> Mono + .empty())); + + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-tool", Map.of())).block(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Notification sent"); + }); + } + + @Test + @Order(9) + void shouldPingSuccessfully() { + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("ping-tool", "Tool that sends ping requests", "{}"), (exchange, request) -> { + return exchange.ping() + .then(Mono + .just(new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Ping sent")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(30), + Map.of(McpSchema.METHOD_PING, + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> Mono.just(Map.of("pong", true)))); + + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-tool", Map.of())).block(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Ping sent"); + }); + } + + @Test + @Order(10) + void shouldHandleGetStreamsAndToolCallWithUpgradedTransport() throws InterruptedException { + CountDownLatch notificationLatch = new CountDownLatch(1); + AtomicReference receivedNotification = new AtomicReference<>(); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool description", "{}"), (exchange, request) -> { + exchange.upgradeTransport(); + return Mono.just(callResponse); + }); + + setupServerSession(tool); + + var mcpClient = McpClient + .async(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .openConnectionOnStartup(true) + .build()) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")) + .loggingConsumers(List.of(notification -> { + if ("test message".equals(notification.data())) { + receivedNotification.set(notification.data()); + notificationLatch.countDown(); + } + return Mono.empty(); + })) + .build(); + + try { + mcpClient.initialize().block(); + + // Wait for GET stream establishment + Thread.sleep(500); + + // Call tool with upgraded transport + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()) + .isEqualTo("Tool executed successfully"); + + // Send notification through GET stream + serverTransportProvider + .notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + new McpSchema.LoggingMessageNotification(McpSchema.LoggingLevel.INFO, "server", "test message")) + .block(); + + // Verify notification received + assertThat(notificationLatch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(receivedNotification.get()).isEqualTo("test message"); + } + finally { + mcpClient.close(); + } + } + + @Test + @Order(11) + void shouldFailCreateMessageWithoutSamplingCapabilities() { + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + return Mono.just(mock(CallToolResult.class)); + }); + + setupServerSession(tool); + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + assertThatExceptionOfType(RuntimeException.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + }).withMessageContaining("Client must be configured with sampling capabilities"); + }); + } + + @Test + @Order(12) + void shouldFailCreateElicitationWithoutElicitationCapabilities() { + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + return Mono.just(mock(CallToolResult.class)); + }); + + setupServerSession(tool); + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + assertThatExceptionOfType(RuntimeException.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + }).withMessageContaining("Client must be configured with elicitation capabilities"); + }); + } + + @Test + @Order(13) + void shouldFailListRootsWithoutCapability() { + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + return exchange.listRoots().then(Mono.just(mock(CallToolResult.class))); + }); + + setupServerSession(tool); + var mcpClient = clientBuilder.build(); + executeTestWithClient(mcpClient, () -> { + assertThatExceptionOfType(RuntimeException.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + }).withMessageContaining("Roots not supported"); + }); + } + + @Test + @Order(1000) + void shouldHandleCreateMessageTimeoutSuccess() { + Function> samplingHandler = request -> { + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return Mono + .just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent("Test message"), + "MockModel", McpSchema.CreateMessageResult.StopReason.STOP_SEQUENCE)); + }; + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(McpSchema.ModelPreferences.builder().build()) + .build(); + return exchange.createMessage(createMessageRequest) + .then(Mono.just(new CallToolResult(List.of(new McpSchema.TextContent("Success")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(3), Map.of()); + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().sampling().build(), + samplingHandler, null); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Success"); + }); + } + + @Test + @Order(1001) + void shouldHandleCreateElicitationTimeoutSuccess() { + Function> elicitationHandler = request -> { + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return Mono.just(new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message()))); + }; + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object")) + .build(); + return exchange.createElicitation(elicitRequest) + .then(Mono.just(new CallToolResult(List.of(new McpSchema.TextContent("Success")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(3), Map.of()); + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().elicitation().build(), null, + elicitationHandler); + executeTestWithClient(mcpClient, () -> { + var result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Success"); + }); + } + + @Test + @Order(1002) + void shouldHandleCreateMessageTimeoutFailure() { + Function> samplingHandler = request -> { + try { + TimeUnit.SECONDS.sleep(3); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return Mono + .just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent("Test message"), + "MockModel", McpSchema.CreateMessageResult.StopReason.STOP_SEQUENCE)); + }; + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(McpSchema.ModelPreferences.builder().build()) + .build(); + return exchange.createMessage(createMessageRequest) + .then(Mono.just(new CallToolResult(List.of(new McpSchema.TextContent("Success")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(2), Map.of()); + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().sampling().build(), + samplingHandler, null); + executeTestWithClient(mcpClient, () -> { + assertThatExceptionOfType(RuntimeException.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + }).withMessageContaining("Did not observe any item or terminal signal within"); + }); + } + + @Test + @Order(1003) + void shouldHandleCreateElicitationTimeoutFailure() { + Function> elicitationHandler = request -> { + try { + TimeUnit.SECONDS.sleep(3); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return Mono.just(new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message()))); + }; + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool", "{}"), (exchange, request) -> { + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object")) + .build(); + return exchange.createElicitation(elicitRequest) + .then(Mono.just(new CallToolResult(List.of(new McpSchema.TextContent("Success")), null))); + }); + + setupServerSession(tool, Duration.ofSeconds(2), Map.of()); + var mcpClient = buildClientWithCapabilities(McpSchema.ClientCapabilities.builder().elicitation().build(), null, + elicitationHandler); + executeTestWithClient(mcpClient, () -> { + assertThatExceptionOfType(RuntimeException.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + }).withMessageContaining("Did not observe any item or terminal signal within"); + }); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..a9fa4d5bb --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,63 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + TomcatTestUtil() { + // Prevent instantiation + } + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml new file mode 100644 index 000000000..37f43a17a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml @@ -0,0 +1,15 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..119275157 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,9 +183,9 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(listeningTransport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, listeningTransport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, + notificationHandlers)); } // --------------------------------------- @@ -214,7 +214,6 @@ private Mono asyncInitializeRequestHandler( "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", initializeRequest.protocolVersion(), serverProtocolVersion); } - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, this.serverInfo, this.instructions)); }); @@ -340,7 +339,8 @@ public Mono notifyToolsListChanged() { private McpServerSession.RequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + List tools = new ArrayList<>(); + tools.addAll(this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList()); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; @@ -356,12 +356,11 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( .filter(tr -> callToolRequest.name().equals(tr.tool().name())) .findAny(); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + if (toolSpecification.isPresent()) { + return toolSpecification.get().call().apply(exchange, callToolRequest.arguments()); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); }; } @@ -636,6 +635,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { new TypeReference() { }); + // This will update both the exchange and session logging levels exchange.setMinLoggingLevel(newMinLoggingLevel.level()); // FIXME: this field is deprecated and should be removed together diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index e56c695fa..ad5203cf1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -8,11 +8,14 @@ import java.util.Collections; import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -45,18 +48,22 @@ public class McpAsyncServerExchange { public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { }; + private final String transportId; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. * @param clientCapabilities The client capabilities that define the supported * features and functionality. * @param clientInfo The client implementation information. + * @param transportId The transport ID to use for outgoing messages */ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo) { + McpSchema.Implementation clientInfo, String transportId) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.transportId = transportId; } /** @@ -75,6 +82,28 @@ public McpSchema.Implementation getClientInfo() { return this.clientInfo; } + /** + * If the exchange's session is using StreamableHttp: Upgrades the transport + * referenced by this exchange's transportId to an SSE stream if it isn't already one. + */ + private void establishSseStream() { + final McpServerTransport currentTransport = session.getTransport(transportId); + if (session.isStreamableHttp() + && currentTransport instanceof StreamableHttpServerTransportProvider.HttpTransport transport) { + session.registerTransport(transportId, + new StreamableHttpServerTransportProvider.SseTransport(transport.getObjectMapper(), + transport.getResponse(), transport.getAsyncContext(), null, transportId, session.getId())); + } + } + + /** + * This is for tool writers to use if they want to send their tool response over an + * SSE stream without using any other McpAsyncServerExchange methods + */ + public void upgradeTransport() { + establishSseStream(); + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM @@ -92,6 +121,9 @@ public McpSchema.Implementation getClientInfo() { * Specification */ public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + + establishSseStream(); + if (this.clientCapabilities == null) { return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); } @@ -99,7 +131,7 @@ public Mono createMessage(McpSchema.CreateMessage return Mono.error(new McpError("Client must be configured with sampling capabilities")); } return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); + CREATE_MESSAGE_RESULT_TYPE_REF, transportId); } /** @@ -117,14 +149,17 @@ public Mono createMessage(McpSchema.CreateMessage * Specification */ public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + + establishSseStream(); + if (this.clientCapabilities == null) { return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); } if (this.clientCapabilities.elicitation() == null) { return Mono.error(new McpError("Client must be configured with elicitation capabilities")); } - return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, - ELICITATION_RESULT_TYPE_REF); + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICITATION_RESULT_TYPE_REF, + transportId); } /** @@ -153,8 +188,11 @@ public Mono listRoots() { * @return A Mono that emits the list of roots result containing */ public Mono listRoots(String cursor) { + + establishSseStream(); + return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); + LIST_ROOTS_RESULT_TYPE_REF, transportId); } /** @@ -165,13 +203,16 @@ public Mono listRoots(String cursor) { */ public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + establishSseStream(); + if (loggingMessageNotification == null) { return Mono.error(new McpError("Logging message must not be null")); } return Mono.defer(() -> { if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { - return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification, + transportId); } return Mono.empty(); }); @@ -182,7 +223,10 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN * @return A Mono that completes with clients's ping response */ public Mono ping() { - return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); + + establishSseStream(); + + return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF, transportId); } /** @@ -190,9 +234,11 @@ public Mono ping() { * filtered out. * @param minLoggingLevel The minimum logging level */ - void setMinLoggingLevel(LoggingLevel minLoggingLevel) { + public void setMinLoggingLevel(LoggingLevel minLoggingLevel) { Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); this.minLoggingLevel = minLoggingLevel; + // Also update the session level for future exchanges + this.session.setMinLoggingLevel(minLoggingLevel); } private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index e61722a82..9d7766557 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -14,6 +14,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -58,7 +59,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * the roots list changes * @param instructions The server instructions text */ - Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + public Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, List resourceTemplates, Map prompts, diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java new file mode 100644 index 000000000..5a4331cf5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +/** + * Handler interface for session lifecycle and runtime events in the Streamable HTTP + * transport. + * + *

+ * This interface provides hooks for monitoring and responding to various session-related + * events that occur during the operation of the HTTP-based MCP server transport. + * Implementations can use these callbacks to: + *

    + *
  • Log session activities
  • + *
  • Implement custom session management logic
  • + *
  • Handle error conditions
  • + *
  • Perform cleanup operations
  • + *
+ * + * @author Zachary German + */ +public interface SessionHandler { + + /** + * Called when a new session is created. + * @param sessionId The ID of the newly created session + * @param context Additional context information (may be null) + */ + void onSessionCreate(String sessionId, Object context); + + /** + * Called when a session is closed. + * @param sessionId The ID of the closed session + */ + void onSessionClose(String sessionId); + + /** + * Called when a session is not found for a given session ID. + * @param sessionId The session ID that was not found + * @param request The HTTP request that referenced the missing session + * @param response The HTTP response that will be sent to the client + */ + void onSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response); + + /** + * Called when an error occurs while sending a notification to a session. + * @param sessionId The ID of the session where the error occurred + * @param error The error that occurred + */ + void onSendNotificationError(String sessionId, Throwable error); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 000000000..d32038b78 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,893 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.SseEvent; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.nio.charset.StandardCharsets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +import static java.util.Objects.requireNonNullElse; + +/** + * MCP Streamable HTTP transport provider that uses a single session class to manage all + * streams and transports. + * + *

+ * Key improvements over the original implementation: + *

    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests, providing direct-HTTP or SSE-streamed responses. + *
  • Provides callbacks for session lifecycle and errors. + *
  • Enforces allowed 'Origin' header values if configured. + *
  • Provides a default constructor and default values for all constructor parameters. + *
+ * + * @author Zachary German + */ +@WebServlet(asyncSupported = true) +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String SESSION_ID_HEADER = "Mcp-Session-Id"; + + public static final String LAST_EVENT_ID_HEADER = "Last-Event-Id"; + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String ACCEPT_HEADER = "Accept"; + + public static final String ORIGIN_HEADER = "Origin"; + + public static final String ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin"; + + public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*"; + + public static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"; + + public static final String CACHE_CONTROL_HEADER = "Cache-Control"; + + public static final String CONNECTION_HEADER = "Connection"; + + public static final String CACHE_CONTROL_NO_CACHE = "no-cache"; + + public static final String CONNECTION_KEEP_ALIVE = "keep-alive"; + + public static final String MCP_SESSION_ID = "MCP-Session-ID"; + + public static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + /** com.fasterxml.jackson.databind.ObjectMapper */ + private static final ObjectMapper DEFAULT_OBJECT_MAPPER = new ObjectMapper(); + + /** UUID.randomUUID().toString() */ + private static final Supplier DEFAULT_SESSION_ID_PROVIDER = () -> UUID.randomUUID().toString(); + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling MCP requests */ + private final String mcpEndpoint; + + /** Supplier for generating unique session IDs */ + private final Supplier sessionIdProvider; + + /** Sessions map, keyed by Session ID */ + private static final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Optional allowed 'Origin' header value list. Not enforced if empty. */ + private final List allowedOrigins = new ArrayList<>(); + + /** Callback interface for session lifecycle and errors */ + private SessionHandler sessionHandler; + + /** Factory for McpServerSession takes session IDs */ + private McpServerSession.StreamableHttpSessionFactory streamableHttpSessionFactory; + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + Supplier sessionIdProvider) { + this.objectMapper = requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER); + this.mcpEndpoint = requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT); + this.sessionIdProvider = requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER); + } + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider() { + this(null, null, null); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Required but not used for this implementation + } + + public void setStreamableHttpSessionFactory(McpServerSession.StreamableHttpSessionFactory sessionFactory) { + this.streamableHttpSessionFactory = sessionFactory; + } + + public void setSessionHandler(SessionHandler sessionHandler) { + this.sessionHandler = sessionHandler; + } + + public void setAllowedOrigins(List allowedOrigins) { + this.allowedOrigins.clear(); + this.allowedOrigins.addAll(allowedOrigins); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .filter(session -> session.getState() == session.STATE_INITIALIZED) + .flatMap(session -> session.sendNotification(method, params).doOnError(e -> { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + if (sessionHandler != null) { + sessionHandler.onSendNotificationError(session.getId(), e); + } + }).onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.closeGracefully() + .doOnError(e -> logger.error("Error closing session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + }); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("GET request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null || !acceptHeader.contains(TEXT_EVENT_STREAM)) { + logger.debug("Accept header missing or does not include {}", TEXT_EVENT_STREAM); + sendErrorResponse(response, "Accept header must include text/event-stream"); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + else { + response.setHeader(SESSION_ID_HEADER, sessionId); + } + + McpServerSession session = sessions.get(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + // Delayed until version negotiation is implemented. + /* + * if (session.getState().equals(session.STATE_INITIALIZED) && + * request.getHeader(PROTOCOL_VERSION_HEADER) == null) { + * sendErrorResponse(response, "Protocol version missing in request header"); } + */ + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + String lastEventId = request.getHeader(LAST_EVENT_ID_HEADER); + + if (lastEventId == null) { // Just opening a listening stream + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId, + session.LISTENING_TRANSPORT, sessionId); + logger.debug("Registered SSE transport {} for session {}", session.LISTENING_TRANSPORT, sessionId); + } + else { // Asking for a stream to replay events from a previous request + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId, + request.getRequestId(), sessionId); + logger.debug("Registered SSE transport {} for session {}", request.getRequestId(), sessionId); + } + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null + || (!acceptHeader.contains(APPLICATION_JSON) || !acceptHeader.contains(TEXT_EVENT_STREAM))) { + logger.debug("Accept header validation failed. Header: {}", acceptHeader); + sendErrorResponse(response, "Accept header must include both application/json and text/event-stream"); + return; + } + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + StringBuilder body = new StringBuilder(); + ServletInputStream inputStream = request.getInputStream(); + + inputStream.setReadListener(new ReadListener() { + @Override + public void onDataAvailable() throws IOException { + int len; + byte[] buffer = new byte[1024]; + while (inputStream.isReady() && (len = inputStream.read(buffer)) != -1) { + body.append(new String(buffer, 0, len, StandardCharsets.UTF_8)); + } + } + + @Override + public void onAllDataRead() throws IOException { + try { + logger.debug("Parsing JSON-RPC message: {}", body.toString()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + boolean isInitializeRequest = false; + String sessionId = request.getHeader(SESSION_ID_HEADER); + + if (message instanceof McpSchema.JSONRPCRequest req + && McpSchema.METHOD_INITIALIZE.equals(req.method())) { + isInitializeRequest = true; + logger.debug("Detected initialize request"); + if (sessionId == null) { + sessionId = sessionIdProvider.get(); + logger.debug("Created new session ID for initialize request: {}", sessionId); + } + } + + if (!isInitializeRequest && sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + asyncContext.complete(); + return; + } + + McpServerSession session = getOrCreateSession(sessionId, isInitializeRequest); + if (session == null) { + logger.error("Failed to create session for sessionId: {}", sessionId); + handleSessionNotFound(sessionId, request, response); + asyncContext.complete(); + return; + } + + // Delayed until version negotiation is implemented. + /* + * if (session.getState().equals(session.STATE_INITIALIZED) && + * request.getHeader(PROTOCOL_VERSION_HEADER) == null) { + * sendErrorResponse(response, + * "Protocol version missing in request header"); } + */ + + logger.debug("Using session: {}", sessionId); + + response.setHeader(SESSION_ID_HEADER, sessionId); + + final String transportId; + if (message instanceof JSONRPCRequest req) { + transportId = req.id().toString(); + } + else if (message instanceof JSONRPCResponse resp) { + transportId = resp.id().toString(); + } + else { + transportId = null; + } + + // Only set content type for requests + if (message instanceof McpSchema.JSONRPCRequest) { + response.setContentType(APPLICATION_JSON); + } + + if (transportId != null) { // Not needed for notifications (null + // transportId) + HttpTransport httpTransport = new HttpTransport(objectMapper, response, asyncContext); + session.registerTransport(transportId, httpTransport); + } + + // For notifications, we need to handle the HTTP response manually + // since no JSON response is sent + if (message instanceof McpSchema.JSONRPCNotification) { + session.handle(message).doOnSuccess(v -> { + logger.debug("[NOTIFICATION] Sending empty HTTP response for notification"); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_OK); + response.setCharacterEncoding("UTF-8"); + } + asyncContext.complete(); + } + catch (Exception e) { + logger.error("Failed to send notification response: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in message handling: {}", e.getMessage(), e); + asyncContext.complete(); + }).contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + } + else { + // For requests, let the transport handle the response + session.handle(message) + .doOnSuccess(v -> logger.info("Message handling completed successfully for transport: {}", + transportId)) + .doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e)) + .doFinally(signalType -> { + logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType); + session.unregisterTransport(transportId); + }) + .contextWrite(Context.of(MCP_SESSION_ID, sessionId)) + .subscribe(null, error -> { + logger.error("Error in message handling chain: {}", error.getMessage(), error); + asyncContext.complete(); + }); + } + + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + sendErrorResponse(response, "Invalid JSON-RPC message: " + e.getMessage()); + asyncContext.complete(); + } + } + + @Override + public void onError(Throwable t) { + logger.error("Error reading request body: {}", t.getMessage()); + try { + sendErrorResponse(response, "Error reading request: " + t.getMessage()); + } + catch (IOException e) { + logger.error("Failed to write error response", e); + } + asyncContext.complete(); + } + }); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + else { + response.setHeader(SESSION_ID_HEADER, sessionId); + } + + McpServerSession session = sessions.remove(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + session.closeGracefully().contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + logger.debug("Session closed via DELETE request: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionClose(sessionId); + } + + response.setStatus(HttpServletResponse.SC_OK); + } + + private boolean validateOrigin(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (!allowedOrigins.isEmpty()) { + String origin = request.getHeader(ORIGIN_HEADER); + if (!allowedOrigins.contains(origin)) { + logger.debug("Origin header does not match allowed origins: '{}'", origin); + response.sendError(HttpServletResponse.SC_FORBIDDEN); + return false; + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, origin); + } + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, ALLOW_ORIGIN_DEFAULT_VALUE); + } + return true; + } + + private boolean validateEndpoint(String requestURI, HttpServletResponse response) throws IOException { + if (!requestURI.endsWith(mcpEndpoint)) { + logger.debug("URI does not match MCP endpoint: '{}'", mcpEndpoint); + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return false; + } + return true; + } + + private boolean validateNotClosing(HttpServletResponse response) throws IOException { + if (isClosing.get()) { + logger.debug("Server is shutting down, rejecting request"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return false; + } + return true; + } + + protected McpServerSession getOrCreateSession(String sessionId, boolean createIfMissing) { + McpServerSession session = sessions.get(sessionId); + logger.debug("Looking for session: {}, found: {}", sessionId, session != null); + if (session == null && createIfMissing) { + logger.debug("Creating new session: {}", sessionId); + session = streamableHttpSessionFactory.create(sessionId); + session.setIsStreamableHttp(true); // TODO: Remove this if we split SHTTP server + // session to its own class. + sessions.put(sessionId, session); + logger.debug("Created new session: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionCreate(sessionId, null); + } + } + return session; + } + + private void handleSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response) + throws IOException { + sendErrorResponse(response, "Session not found: " + sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionNotFound(sessionId, request, response); + } + } + + private void sendErrorResponse(HttpServletResponse response, String message) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson(message)); + } + + private String createErrorJson(String message) { + try { + return objectMapper.writeValueAsString(new McpError(message)); + } + catch (IOException e) { + logger.error("Failed to serialize error message", e); + return "{\"error\":\"" + message + "\"}"; + } + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + private Map extractHeaders(HttpServletRequest request) { + Map headers = new HashMap<>(); + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String name = headerNames.nextElement(); + headers.put(name, request.getHeader(name)); + } + return headers; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ObjectMapper objectMapper = DEFAULT_OBJECT_MAPPER; + + private String mcpEndpoint = DEFAULT_MCP_ENDPOINT; + + private Supplier sessionIdProvider = DEFAULT_SESSION_ID_PROVIDER; + + public Builder withObjectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withMcpEndpoint(String mcpEndpoint) { + Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + public Builder withSessionIdProvider(Supplier sessionIdProvider) { + Assert.notNull(sessionIdProvider, "SessionIdProvider must not be null"); + this.sessionIdProvider = sessionIdProvider; + return this; + } + + public StreamableHttpServerTransportProvider build() { + return new StreamableHttpServerTransportProvider(objectMapper, mcpEndpoint, sessionIdProvider); + } + + } + + private enum ResponseType { + + IMMEDIATE, STREAM + + } + + /** + * SSE transport implementation. + */ + public static class SseTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(SseTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + private final Sinks.Many eventSink = Sinks.many().unicast().onBackpressureBuffer(); + + private final Map eventHistory = new ConcurrentHashMap<>(); + + private final String id; + + private final String sessionId; + + public SseTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext, + String lastEventId, String transportId, String sessionId) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + this.id = transportId; + this.sessionId = sessionId; + + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + + logger.debug("Establishing SSE stream with ID: {} for session: {}", transportId, sessionId); + setupSseStream(lastEventId); + + logger.debug("Registering SSE transport with ID: {} for session: {}", transportId, sessionId); + sessions.get(sessionId).registerTransport(transportId, this); + } + + private void setupSseStream(String lastEventId) { + try { + PrintWriter writer = response.getWriter(); + + eventSink.asFlux().doOnNext(event -> { + try { + if (event.id() != null) { + writer.write("id: " + event.id() + "\n"); + } + if (event.event() != null) { + writer.write("event: " + event.event() + "\n"); + } + writer.write("data: " + event.data() + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + catch (IOException e) { + logger.debug("Error writing to SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnComplete(() -> { + try { + writer.close(); + } + finally { + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in SSE stream: {}", e.getMessage()); + asyncContext.complete(); + }).contextWrite(Context.of(MCP_SESSION_ID, response.getHeader(SESSION_ID_HEADER))).subscribe(); + + // Replay events if requested + if (lastEventId != null) { + replayEventsAfter(lastEventId); + } + + } + catch (IOException e) { + logger.error("Failed to setup SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + + private void replayEventsAfter(String lastEventId) { + try { + McpServerSession session = sessions.get(sessionId); + String transportIdOfLastEventId = session.getTransportIdForEvent(lastEventId); + Map transportEventHistory = session + .getTransportEventHistory(transportIdOfLastEventId); + List eventIds = transportEventHistory.keySet() + .stream() + .map(Long::parseLong) + .filter(key -> key > Long.parseLong(lastEventId)) + .sorted() + .map(String::valueOf) + .collect(Collectors.toList()); + for (String eventId : eventIds) { + SseEvent event = transportEventHistory.get(eventId); + if (event != null) { + eventSink.tryEmitNext(event); + } + } + logger.debug("Completing SSE stream after replaying events"); + eventSink.tryEmitComplete(); + } + catch (NumberFormatException e) { + logger.warn("Invalid last event ID: {}", lastEventId); + } + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = sessions.get(sessionId).incrementAndGetEventId(id); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + + if (message instanceof McpSchema.JSONRPCResponse) { + logger.debug("Completing SSE stream after sending response"); + eventSink.tryEmitComplete(); + McpServerSession session = sessions.get(sessionId); + if (session != null) { + session.setTransportEventHistory(id, eventHistory); + } + } + + return Mono.empty(); + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage()); + return Mono.error(e); + } + } + + /** + * Sends a stream of messages for Flux responses. + */ + public Mono sendMessageStream(Flux messageStream) { + return messageStream.doOnNext(message -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = sessions.get(sessionId).incrementAndGetEventId(id); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE stream event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + } + catch (Exception e) { + logger.error("Failed to send stream message: {}", e.getMessage()); + eventSink.tryEmitError(e); + } + }).doOnComplete(() -> { + logger.debug("Completing SSE stream after sending all stream messages"); + eventSink.tryEmitComplete(); + McpServerSession session = sessions.get(sessionId); + if (session != null) { + session.setTransportEventHistory(id, eventHistory); + } + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + eventSink.tryEmitComplete(); + McpServerSession session = sessions.get(sessionId); + if (session != null) { + session.setTransportEventHistory(id, eventHistory); + } + logger.debug("SSE transport closed gracefully"); + }); + } + + } + + /** + * HTTP transport implementation for immediate responses. + */ + public static class HttpTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + public HttpTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + } + + public ObjectMapper getObjectMapper() { + return this.objectMapper; + } + + public HttpServletResponse getResponse() { + return this.response; + } + + public AsyncContext getAsyncContext() { + return this.asyncContext; + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + if (response.isCommitted()) { + logger.warn("Response already committed, cannot send message"); + return; + } + + response.setCharacterEncoding("UTF-8"); + response.setStatus(HttpServletResponse.SC_OK); + + // For notifications, don't write any content (empty response) + if (message instanceof McpSchema.JSONRPCNotification) { + logger.debug("Sending empty 200 response for notification"); + // Just complete the response with no content + } + else { + // For requests/responses, write JSON content + String jsonText = objectMapper.writeValueAsString(message); + PrintWriter writer = response.getWriter(); + writer.write(jsonText); + writer.flush(); + logger.debug("Successfully sent immediate response: {}", jsonText); + } + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage(), e); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + catch (Exception ignored) { + } + } + finally { + asyncContext.complete(); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + asyncContext.complete(); + } + catch (Exception e) { + logger.debug("Error completing async context: {}", e.getMessage()); + } + logger.debug("HTTP transport closed gracefully"); + }); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index cc7d2abf8..454f1fc4b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.spec; import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.util.Assert; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -47,7 +49,7 @@ public class McpClientSession implements McpSession { private final McpClientTransport transport; /** Map of pending responses keyed by request ID */ - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); /** Map of request handlers keyed by method name */ private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); @@ -231,10 +233,10 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti /** * Generates a unique request ID in a non-blocking way. Combines a session-specific * prefix with an atomic counter to ensure uniqueness. - * @return A unique request ID string + * @return A unique request ID from String */ - private String generateRequestId() { - return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + private McpId generateRequestId() { + return McpId.of(this.sessionPrefix + "-" + this.requestCounter.getAndIncrement()); } /** @@ -247,7 +249,7 @@ private String generateRequestId() { */ @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); + McpId requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { logger.debug("Sending message for method {}", method); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java index 13e43240b..6d177e0f9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -15,7 +15,7 @@ public McpError(JSONRPCError jsonRpcError) { } public McpError(Object error) { - super(error.toString()); + super(String.valueOf(error)); } public JSONRPCError getJsonRpcError() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 9be585cea..f7414d0c9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -18,12 +18,25 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static java.util.Objects.requireNonNull; + /** * Based on the JSON-RPC 2.0 * specification and the { + + @Override + public McpId deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + JsonToken t = p.getCurrentToken(); + if (t == JsonToken.VALUE_STRING) { + return new McpId(p.getText()); + } + else if (t.isNumeric()) { + return new McpId(p.getNumberValue()); + } + throw JsonMappingException.from(p, "MCP 'id' must be a non-null String or Number"); + } + + } + + public static class Serializer extends JsonSerializer { + + @Override + public void serialize(McpId id, JsonGenerator gen, SerializerProvider serializers) throws IOException { + if (id.isString()) { + gen.writeString(id.asString()); + } + else { + gen.writeNumber(id.asNumber().toString()); + } + } + + } + + } + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, GetPromptRequest, PaginatedRequest, ReadResourceRequest { @@ -200,18 +312,16 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCRequest( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, - @JsonProperty("id") Object id, + @JsonProperty("id") McpId id, @JsonProperty("params") Object params) implements JSONRPCMessage { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @@ -221,11 +331,10 @@ public record JSONRPCNotification( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCResponse( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("id") Object id, + @JsonProperty("id") McpId id, @JsonProperty("result") Object result, @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..53ca15dfd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,16 +1,24 @@ package io.modelcontextprotocol.spec; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.SseEvent; +import io.modelcontextprotocol.spec.McpSchema.McpId; +import io.modelcontextprotocol.spec.McpError; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; @@ -23,7 +31,29 @@ public class McpServerSession implements McpSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + /** + * Map of all registered session transport instances, keyed by the ID of the request + * which invoked their instantiation + */ + private final ConcurrentHashMap transports = new ConcurrentHashMap<>(); + + /** Generic SSE transport established by GET calls for listening to the server */ + private McpServerTransport listeningTransport; + + public static final String LISTENING_TRANSPORT = "listeningTransport"; + + private final AtomicLong eventCounter = new AtomicLong(0); + + /** Maps a given event ID to the transport ID that it was sent over */ + private final Map eventTransports = new ConcurrentHashMap<>(); + + /** + * Maps SSE transport IDs to a Map containing all events sent over them keyed by event + * ID, added upon the transport's termination + */ + private final Map> transportEventHistories = new ConcurrentHashMap<>(); private final String id; @@ -40,26 +70,34 @@ public class McpServerSession implements McpSession { private final Map notificationHandlers; - private final McpServerTransport transport; - - private final Sinks.One exchangeSink = Sinks.one(); - private final AtomicReference clientCapabilities = new AtomicReference<>(); private final AtomicReference clientInfo = new AtomicReference<>(); - private static final int STATE_UNINITIALIZED = 0; + public static final int STATE_UNINITIALIZED = 0; - private static final int STATE_INITIALIZING = 1; + public static final int STATE_INITIALIZING = 1; - private static final int STATE_INITIALIZED = 2; + public static final int STATE_INITIALIZED = 2; private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + // TODO: Remove this if we split SHTTP Server Session into its own class. + private AtomicBoolean isStreamableHttp = new AtomicBoolean(false); + + public void setIsStreamableHttp(boolean b) { + isStreamableHttp.set(b); + } + + public boolean isStreamableHttp() { + return isStreamableHttp.get(); + } + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id - * @param transport the transport to use * @param initHandler called when a * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the * server @@ -69,18 +107,43 @@ public class McpServerSession implements McpSession { * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ - public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + public McpServerSession(String id, Duration requestTimeout, McpServerTransport listeningTransport, InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, Map> requestHandlers, Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; - this.transport = transport; + this.listeningTransport = listeningTransport; this.initRequestHandler = initHandler; this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } + // Alternate constructor used by StreamableHttp servers + public McpServerSession(String id, Duration requestTimeout, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { + this(id, requestTimeout, null, initHandler, initNotificationHandler, requestHandlers, notificationHandlers); + } + + /** + * Updates the session's minimum logging level for all future exchanges. + */ + public void setMinLoggingLevel(McpSchema.LoggingLevel level) { + if (level != null) { + this.minLoggingLevel = level; + logger.debug("Updated session {} minimum logging level to {}", id, level); + } + } + + /** + * Retrieve the session initialization state + * @return session initialization state + */ + public int getState() { + return state.intValue(); + } + /** * Retrieve the session id. * @return session id @@ -89,6 +152,90 @@ public String getId() { return this.id; } + /** + * Increments the session-specific event counter, maps it to the given transport ID + * for replayability support, then returns the event ID + * @param transportId + * @return an event ID unique to the session + */ + public String incrementAndGetEventId(String transportId) { + final String eventId = String.valueOf(eventCounter.incrementAndGet()); + eventTransports.put(eventId, transportId); + return eventId; + } + + /** + * Used for replayability support to get the transport ID of a given event ID + * @param eventId + * @return The ID of the transport instance that the given event ID was sent over + */ + public String getTransportIdForEvent(String eventId) { + return eventTransports.get(eventId); + } + + /** + * Used for replayability support to set the event history of a given transport ID + * @param transportId + * @param eventHistory + */ + public void setTransportEventHistory(String transportId, Map eventHistory) { + transportEventHistories.put(transportId, eventHistory); + } + + /** + * Used for replayability support to retrieve the entire event history for a given + * transport ID + * @param transportId + * @return Map of SseEvent objects, keyed by event ID + */ + public Map getTransportEventHistory(String transportId) { + return transportEventHistories.get(transportId); + } + + /** + * Registers a transport for this session. + * @param transportId unique identifier for the transport + * @param transport the transport instance + */ + public void registerTransport(String transportId, McpServerTransport transport) { + if (transportId.equals(LISTENING_TRANSPORT)) { + this.listeningTransport = transport; + logger.debug("Registered listening transport for session {}", id); + return; + } + transports.put(transportId, transport); + logger.debug("Registered transport {} for session {}", transportId, id); + } + + /** + * Unregisters a transport from this session. + * @param transportId the transport identifier to remove + */ + public void unregisterTransport(String transportId) { + if (transportId.equals(LISTENING_TRANSPORT)) { + this.listeningTransport = null; + logger.debug("Unregistered listening transport for session {}", id); + return; + } + McpServerTransport removed = transports.remove(transportId); + if (removed != null) { + logger.debug("Unregistered transport {} from session {}", transportId, id); + } + } + + /** + * Gets a transport by its identifier. + * @param transportId the transport identifier + * @return the transport, or null if not found + */ + public McpServerTransport getTransport(String transportId) { + if (transportId.equals(LISTENING_TRANSPORT)) { + return this.listeningTransport; + } + logger.debug("Found transport {} in session {}", transportId, id); + return transports.get(transportId); + } + /** * Called upon successful initialization sequence between the client and the server * with the client capabilities and information. @@ -104,19 +251,47 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl this.clientInfo.lazySet(clientInfo); } - private String generateRequestId() { - return this.id + "-" + this.requestCounter.getAndIncrement(); + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities.get(); + } + + public McpSchema.Implementation getClientInfo() { + return this.clientInfo.get(); + } + + private McpId generateRequestId() { + return McpId.of(this.id + "-" + this.requestCounter.getAndIncrement()); + } + + /** + * Gets a request handler by method name. + */ + public RequestHandler getRequestHandler(String method) { + return requestHandlers.get(method); } @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); + return sendRequest(method, requestParams, typeRef, LISTENING_TRANSPORT); + } + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef, String transportId) { + McpServerTransport transport = getTransport(transportId); + if (transport == null) { + // Fallback to listening transport if specific transport not found + transport = getTransport(LISTENING_TRANSPORT); + if (transport == null) { + return Mono.error(new RuntimeException("Transport not found: " + transportId)); + } + } + + final McpServerTransport finalTransport = transport; + McpId requestId = this.generateRequestId(); return Mono.create(sink -> { this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + Flux.from(finalTransport.sendMessage(jsonrpcRequest)).subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); sink.error(error); @@ -125,22 +300,33 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } + else if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } else { - if (typeRef.getType().equals(Void.class)) { - sink.complete(); - } - else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); - } + T result = finalTransport.unmarshalFrom(jsonRpcResponse.result(), typeRef); + sink.next(result); } }); } @Override public Mono sendNotification(String method, Object params) { + return sendNotification(method, params, LISTENING_TRANSPORT); + } + + public Mono sendNotification(String method, Object params, String transportId) { + McpServerTransport transport = getTransport(transportId); + if (transport == null) { + // Fallback to listening transport if specific transport not found + transport = getTransport(LISTENING_TRANSPORT); + if (transport == null) { + return Mono.error(new RuntimeException("Transport not found: " + transportId)); + } + } McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); - return this.transport.sendMessage(jsonrpcNotification); + return transport.sendMessage(jsonrpcNotification); } /** @@ -170,13 +356,22 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); + final String transportId = determineTransportId(request); return handleIncomingRequest(request).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - // TODO: Should the error go to SSE or back as POST return? - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); + McpServerTransport transport = getTransportWithFallback(transportId); + return transport != null ? transport.sendMessage(errorResponse).then(Mono.empty()) : Mono.empty(); + }).flatMap(response -> { + McpServerTransport transport = getTransportWithFallback(transportId); + if (transport != null) { + return transport.sendMessage(response); + } + else { + return Mono.error(new RuntimeException("No transport available")); + } + }); } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization @@ -203,8 +398,10 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { + McpSchema.InitializeRequest initializeRequest = transports.isEmpty() ? listeningTransport + .unmarshalFrom(request.params(), new TypeReference() { + }) : transports.get(String.valueOf(request.id())) + .unmarshalFrom(request.params(), new TypeReference() { }); this.state.lazySet(STATE_INITIALIZING); @@ -222,7 +419,10 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + McpAsyncServerExchange requestExchange = new McpAsyncServerExchange(this, clientCapabilities.get(), + clientInfo.get(), determineTransportId(request)); + requestExchange.setMinLoggingLevel(minLoggingLevel); + resultMono = handler.handle(requestExchange, request.params()); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -242,7 +442,6 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); return this.initNotificationHandler.handle(); } @@ -251,7 +450,10 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + McpAsyncServerExchange notificationExchange = new McpAsyncServerExchange(this, clientCapabilities.get(), + clientInfo.get(), LISTENING_TRANSPORT); + notificationExchange.setMinLoggingLevel(minLoggingLevel); + return handler.handle(notificationExchange, notification.params()); }); } @@ -262,14 +464,56 @@ private MethodNotFoundError getMethodNotFoundError(String method) { return new MethodNotFoundError(method, "Method not found: " + method, null); } + /** + * Determines the appropriate transport ID for a request. Uses request ID for + * per-request routing only if a transport with that ID exists, otherwise falls back + * to listening transport. + */ + private String determineTransportId(McpSchema.JSONRPCRequest request) { + String requestTransportId = request.id().toString(); + // Check if a transport exists for this specific request ID + if (getTransport(requestTransportId) != null) { + return requestTransportId; + } + // Fallback to listening transport + return LISTENING_TRANSPORT; + } + + /** + * Gets a transport with fallback to listening transport. + */ + private McpServerTransport getTransportWithFallback(String transportId) { + McpServerTransport transport = getTransport(transportId); + if (transport == null) { + transport = getTransport(LISTENING_TRANSPORT); + } + return transport; + } + @Override public Mono closeGracefully() { - return this.transport.closeGracefully(); + return Mono.defer(() -> { + List> closeTasks = new ArrayList<>(); + + // Add listening transport if it exists + if (listeningTransport != null) { + closeTasks.add(listeningTransport.closeGracefully()); + } + + // Add all transports from the map + closeTasks.addAll(transports.values().stream().map(McpServerTransport::closeGracefully).toList()); + + return Mono.when(closeTasks); + }); } @Override public void close() { - this.transport.close(); + if (listeningTransport != null) { + listeningTransport.close(); + } + transports.values().forEach(McpServerTransport::close); + transports.clear(); } /** @@ -334,6 +578,25 @@ public interface RequestHandler { } + /** + * A handler for client-initiated requests return Flux. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + public interface StreamingRequestHandler extends RequestHandler { + + /** + * Handles a request from the client which invokes a streamTool. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return Flux that will emit the response to the request. + */ + Flux handleStreaming(McpAsyncServerExchange exchange, Object params); + + } + /** * Factory for creating server sessions which delegate to a provided 1:1 transport * with a connected client. @@ -350,4 +613,21 @@ public interface Factory { } + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface StreamableHttpSessionFactory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param transportId ID of the JSONRPCRequest/JSONRPCResponse the transport is + * serving. + * @return a new server session. + */ + McpServerSession create(String transportId); + + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java b/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java new file mode 100644 index 000000000..f5f288cdd --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java @@ -0,0 +1,4 @@ +package io.modelcontextprotocol.spec; + +public record SseEvent(String id, String event, String data) { +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 3e89c8cef..cde86fe9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -17,6 +17,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; @@ -172,7 +173,7 @@ void testRootsListRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ROOTS_LIST, "test-id", null); + McpSchema.METHOD_ROOTS_LIST, McpId.of("test-id"), null); transport.simulateIncomingMessage(request); // Verify response @@ -180,7 +181,7 @@ void testRootsListRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()) .isEqualTo(new McpSchema.ListRootsResult(List.of(new Root("file:///test/path", "test-root")))); assertThat(response.error()).isNull(); @@ -309,7 +310,7 @@ void testSamplingCreateMessageRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, McpId.of("test-id"), messageRequest); transport.simulateIncomingMessage(request); // Verify response @@ -317,7 +318,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), @@ -350,7 +351,7 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, McpId.of("test-id"), messageRequest); transport.simulateIncomingMessage(request); // Verify error response @@ -358,7 +359,7 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()).isNull(); assertThat(response.error()).isNotNull(); assertThat(response.error().message()).contains("Method not found: sampling/createMessage"); @@ -414,7 +415,7 @@ void testElicitationCreateRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify response @@ -422,7 +423,7 @@ void testElicitationCreateRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { @@ -459,7 +460,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify response @@ -467,7 +468,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { @@ -498,7 +499,7 @@ void testElicitationCreateRequestHandlingWithoutCapability() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify error response @@ -506,7 +507,7 @@ void testElicitationCreateRequestHandlingWithoutCapability() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()).isNull(); assertThat(response.error()).isNotNull(); assertThat(response.error().message()).contains("Method not found: elicitation/create"); @@ -535,7 +536,7 @@ void testPingMessageRequestHandling() { // Simulate incoming ping request from server McpSchema.JSONRPCRequest pingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_PING, "ping-id", null); + McpSchema.METHOD_PING, McpId.of("ping-id"), null); transport.simulateIncomingMessage(pingRequest); // Verify response @@ -543,7 +544,7 @@ void testPingMessageRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("ping-id"); + assertThat(response.id().toString()).isEqualTo("ping-id"); assertThat(response.error()).isNull(); assertThat(response.result()).isInstanceOf(Map.class); assertThat(((Map) response.result())).isEmpty(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e4348be25..fdcb15933 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -106,7 +108,7 @@ void cleanup() { @Test void testMessageProcessing() { // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Simulate receiving the message @@ -137,7 +139,7 @@ void testResponseMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -161,7 +163,7 @@ void testErrorMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -191,7 +193,7 @@ void testGracefulShutdown() { StepVerifier.create(transport.closeGracefully()).verifyComplete(); // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message is not processed after shutdown @@ -236,10 +238,10 @@ void testMultipleMessageProcessing() { """); // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"), Map.of("key", "value1")); - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"), Map.of("key", "value2")); // Verify both messages are processed diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 39066a9a2..4ef60bb1a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -54,7 +54,7 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - exchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + exchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo, "test-transport"); } @Test @@ -65,7 +65,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(singlePageResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -93,11 +93,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -119,7 +119,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(emptyResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -139,7 +139,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(result)); StepVerifier.create(exchange.listRoots("someCursor")).assertNext(listResult -> { @@ -153,7 +153,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -174,11 +174,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -227,13 +227,15 @@ void testLoggingNotificationWithAllowedLevel() { .data("Test error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); // Verify that sendNotification was called exactly once - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport")); } @Test @@ -251,7 +253,8 @@ void testLoggingNotificationWithFilteredLevel() { StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification), + eq("test-transport")); } @Test @@ -269,7 +272,8 @@ void testLoggingNotificationLevelFiltering() { StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification), + eq("test-transport")); // Test INFO (should be filtered) McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() @@ -281,7 +285,8 @@ void testLoggingNotificationLevelFiltering() { StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport")); reset(mockSession); @@ -292,14 +297,15 @@ void testLoggingNotificationLevelFiltering() { .data("Warning message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(warningNotification)).verifyComplete(); // Verify that sendNotification was called exactly once for WARNING level verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(warningNotification)); + eq(warningNotification), eq("test-transport")); // Test ERROR (should be sent) McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() @@ -308,14 +314,15 @@ void testLoggingNotificationLevelFiltering() { .data("Error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(errorNotification)).verifyComplete(); // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification), + eq("test-transport")); } @Test @@ -327,13 +334,15 @@ void testLoggingNotificationWithDefaultLevel() { .data("Info message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport")); } @Test @@ -345,7 +354,8 @@ void testLoggingNotificationWithSessionError() { .data("Test error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session error"))); StepVerifier.create(exchange.loggingNotification(notification)).verifyErrorSatisfies(error -> { @@ -379,7 +389,8 @@ void testLoggingLevelHierarchy() { if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); @@ -398,7 +409,8 @@ void testLoggingLevelHierarchy() { @Test void testCreateElicitationWithNullCapabilities() { // Given - Create exchange with null capabilities - McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo, + "test-transport"); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") @@ -412,7 +424,7 @@ void testCreateElicitationWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -423,7 +435,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { .build(); McpAsyncServerExchange exchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutElicitation, clientInfo); + capabilitiesWithoutElicitation, clientInfo, "test-transport"); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") @@ -437,7 +449,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { // Verify that sendRequest was never called due to missing elicitation // capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -448,7 +460,7 @@ void testCreateElicitationWithComplexRequest() { .build(); McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); // Create a complex elicit request with schema java.util.Map requestedSchema = new java.util.HashMap<>(); @@ -472,7 +484,7 @@ void testCreateElicitationWithComplexRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -492,7 +504,7 @@ void testCreateElicitationWithDeclineAction() { .build(); McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide sensitive information") @@ -503,7 +515,7 @@ void testCreateElicitationWithDeclineAction() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -520,7 +532,7 @@ void testCreateElicitationWithCancelAction() { .build(); McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your information") @@ -531,7 +543,7 @@ void testCreateElicitationWithCancelAction() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -548,14 +560,14 @@ void testCreateElicitationWithSessionError() { .build(); McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { @@ -570,7 +582,8 @@ void testCreateElicitationWithSessionError() { @Test void testCreateMessageWithNullCapabilities() { - McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo, + "test-transport"); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -585,7 +598,7 @@ void testCreateMessageWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -596,7 +609,7 @@ void testCreateMessageWithoutSamplingCapabilities() { .build(); McpAsyncServerExchange exchangeWithoutSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutSampling, clientInfo); + capabilitiesWithoutSampling, clientInfo, "test-transport"); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -610,7 +623,7 @@ void testCreateMessageWithoutSamplingCapabilities() { // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -621,7 +634,7 @@ void testCreateMessageWithBasicRequest() { .build(); McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + clientInfo, "test-transport"); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -636,7 +649,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -657,7 +670,7 @@ void testCreateMessageWithImageContent() { .build(); McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + clientInfo, "test-transport"); // Create request with image content McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -674,7 +687,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -692,7 +705,7 @@ void testCreateMessageWithSessionError() { .build(); McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + clientInfo, "test-transport"); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -700,7 +713,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { @@ -716,7 +729,7 @@ void testCreateMessageWithIncludeContext() { .build(); McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + clientInfo, "test-transport"); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -732,7 +745,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -750,7 +763,8 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.just(expectedResponse)); StepVerifier.create(exchange.ping()).assertNext(result -> { @@ -759,14 +773,16 @@ void testPingWithSuccessfulResponse() { }).verifyComplete(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.error(mcpError)); // When & Then @@ -774,13 +790,15 @@ void testPingWithMcpError() { assertThat(error).isInstanceOf(McpError.class).hasMessage("Server unavailable"); }); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -795,7 +813,8 @@ void testPingMultipleCalls() { }).verifyComplete(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index f643f1ba3..2acce4d40 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -10,6 +10,8 @@ import io.modelcontextprotocol.MockMcpServerTransport; import io.modelcontextprotocol.MockMcpServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -23,7 +25,7 @@ class McpServerProtocolVersionTests { private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); - private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(String requestId, String protocolVersion) { + private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(McpId requestId, String protocolVersion) { return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, requestId, new McpSchema.InitializeRequest(protocolVersion, null, CLIENT_INFO)); } @@ -34,7 +36,7 @@ void shouldUseLatestVersionByDefault() { var transportProvider = new MockMcpServerTransportProvider(serverTransport); McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider .simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); @@ -60,7 +62,7 @@ void shouldNegotiateSpecificVersion() { server.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); @@ -83,7 +85,7 @@ void shouldSuggestLatestVersionForUnsupportedVersion() { McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); @@ -111,7 +113,8 @@ void shouldUseHighestVersionWhenMultipleSupported() { server.setProtocolVersions(List.of(oldVersion, middleVersion, latestVersion)); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index 66d7695e8..b6a013478 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -55,7 +55,7 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - asyncExchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + asyncExchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo, "test-transport"); exchange = new McpSyncServerExchange(asyncExchange); } @@ -67,7 +67,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(singlePageResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -95,11 +95,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -121,7 +121,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(emptyResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -141,7 +141,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(result)); McpSchema.ListRootsResult listResult = exchange.listRoots("someCursor"); @@ -155,7 +155,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -174,11 +174,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -226,13 +226,15 @@ void testLoggingNotificationWithAllowedLevel() { .data("Test error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.empty()); exchange.loggingNotification(notification); // Verify that sendNotification was called exactly once - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport")); } @Test @@ -250,7 +252,8 @@ void testLoggingNotificationWithFilteredLevel() { exchange.loggingNotification(debugNotification); // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification), + eq("test-transport")); } @Test @@ -268,7 +271,8 @@ void testLoggingNotificationLevelFiltering() { exchange.loggingNotification(debugNotification); // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification), + eq("test-transport")); // Test INFO (should be filtered) McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() @@ -280,7 +284,8 @@ void testLoggingNotificationLevelFiltering() { exchange.loggingNotification(infoNotification); // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport")); reset(mockSession); @@ -291,14 +296,15 @@ void testLoggingNotificationLevelFiltering() { .data("Warning message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); exchange.loggingNotification(warningNotification); // Verify that sendNotification was called exactly once for WARNING level verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(warningNotification)); + eq(warningNotification), eq("test-transport")); // Test ERROR (should be sent) McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() @@ -307,14 +313,15 @@ void testLoggingNotificationLevelFiltering() { .data("Error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); exchange.loggingNotification(errorNotification); // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification), + eq("test-transport")); } @Test @@ -326,13 +333,15 @@ void testLoggingNotificationWithDefaultLevel() { .data("Info message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport"))) .thenReturn(Mono.empty()); exchange.loggingNotification(infoNotification); // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification), + eq("test-transport")); } @Test @@ -344,7 +353,8 @@ void testLoggingNotificationWithSessionError() { .data("Test error message") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session error"))); assertThatThrownBy(() -> exchange.loggingNotification(notification)).isInstanceOf(RuntimeException.class) @@ -370,7 +380,8 @@ void testLoggingLevelHierarchy() { if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification), + eq("test-transport"))) .thenReturn(Mono.empty()); exchange.loggingNotification(notification); @@ -390,7 +401,7 @@ void testLoggingLevelHierarchy() { void testCreateElicitationWithNullCapabilities() { // Given - Create exchange with null capabilities McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, - clientInfo); + clientInfo, "test-transport"); McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( asyncExchangeWithNullCapabilities); @@ -404,7 +415,7 @@ void testCreateElicitationWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -415,7 +426,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { .build(); McpAsyncServerExchange asyncExchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutElicitation, clientInfo); + capabilitiesWithoutElicitation, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithoutElicitation = new McpSyncServerExchange(asyncExchangeWithoutElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -429,7 +440,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { // Verify that sendRequest was never called due to missing elicitation // capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -440,7 +451,7 @@ void testCreateElicitationWithComplexRequest() { .build(); McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); // Create a complex elicit request with schema @@ -465,7 +476,7 @@ void testCreateElicitationWithComplexRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -485,7 +496,7 @@ void testCreateElicitationWithDeclineAction() { .build(); McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -497,7 +508,7 @@ void testCreateElicitationWithDeclineAction() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -514,7 +525,7 @@ void testCreateElicitationWithCancelAction() { .build(); McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -526,7 +537,7 @@ void testCreateElicitationWithCancelAction() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -543,7 +554,7 @@ void testCreateElicitationWithSessionError() { .build(); McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + capabilitiesWithElicitation, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -551,7 +562,7 @@ void testCreateElicitationWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithElicitation.createElicitation(elicitRequest)) @@ -567,7 +578,7 @@ void testCreateElicitationWithSessionError() { void testCreateMessageWithNullCapabilities() { McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, - clientInfo); + clientInfo, "test-transport"); McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( asyncExchangeWithNullCapabilities); @@ -582,7 +593,7 @@ void testCreateMessageWithNullCapabilities() { // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -593,7 +604,7 @@ void testCreateMessageWithoutSamplingCapabilities() { .build(); McpAsyncServerExchange asyncExchangeWithoutSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutSampling, clientInfo); + capabilitiesWithoutSampling, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithoutSampling = new McpSyncServerExchange(asyncExchangeWithoutSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -607,7 +618,7 @@ void testCreateMessageWithoutSamplingCapabilities() { // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeReference.class), any()); } @Test @@ -618,7 +629,7 @@ void testCreateMessageWithBasicRequest() { .build(); McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + capabilitiesWithSampling, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -634,7 +645,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -655,7 +666,7 @@ void testCreateMessageWithImageContent() { .build(); McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + capabilitiesWithSampling, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); // Create request with image content @@ -673,7 +684,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -691,7 +702,7 @@ void testCreateMessageWithSessionError() { .build(); McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + capabilitiesWithSampling, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -700,7 +711,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithSampling.createMessage(createMessageRequest)) @@ -716,7 +727,7 @@ void testCreateMessageWithIncludeContext() { .build(); McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + capabilitiesWithSampling, clientInfo, "test-transport"); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -733,7 +744,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeReference.class), eq("test-transport"))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -751,32 +762,37 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.just(expectedResponse)); exchange.ping(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.error(mcpError)); // When & Then assertThatThrownBy(() -> exchange.ping()).isInstanceOf(McpError.class).hasMessage("Server unavailable"); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport"))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -787,7 +803,8 @@ void testPingMultipleCalls() { exchange.ping(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class), + eq("test-transport")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java new file mode 100644 index 000000000..13114e5c9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java new file mode 100644 index 000000000..568abd741 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpSyncServerTests extends AbstractMcpSyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java new file mode 100644 index 000000000..78a1ab15a --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java @@ -0,0 +1,349 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link StreamableHttpServerTransportProvider}. + */ +class StreamableHttpServerTransportProviderTests { + + private StreamableHttpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.StreamableHttpSessionFactory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.StreamableHttpSessionFactory.class); + + when(sessionFactory.create(anyString())).thenReturn(mockSession); + when(mockSession.getId()).thenReturn("test-session-id"); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + transportProvider = new StreamableHttpServerTransportProvider(objectMapper, "/mcp", null); + transportProvider.setStreamableHttpSessionFactory(sessionFactory); + } + + @Test + void shouldCreateSessionOnFirstRequest() { + // Test session creation directly through the getOrCreateSession method + String sessionId = "test-session-1"; + + McpServerSession session = transportProvider.getOrCreateSession(sessionId, true); + + assertThat(session).isNotNull(); + verify(sessionFactory).create(sessionId); + } + + @Test + void shouldHandleSSERequest() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + String sessionId = "test-session-2"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + + // First create a session + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doGet(request, response); + + verify(response).setContentType("text/event-stream"); + verify(response).setCharacterEncoding("UTF-8"); + verify(response).setHeader("Cache-Control", "no-cache"); + verify(response).setHeader("Connection", "keep-alive"); + } + + @Test + void shouldNotifyClients() { + String sessionId = "test-session-3"; + transportProvider.getOrCreateSession(sessionId, true); + + String method = "test/notification"; + String params = "test message"; + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Verify that the session was created + assertThat(transportProvider.getOrCreateSession(sessionId, false)).isNotNull(); + } + + @Test + void shouldCloseGracefully() { + String sessionId = "test-session-4"; + transportProvider.getOrCreateSession(sessionId, true); + + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleInvalidRequestURI() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + when(request.getRequestURI()).thenReturn("/wrong-path"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + transportProvider.doGet(request, response); + transportProvider.doPost(request, response); + transportProvider.doDelete(request, response); + + verify(response, times(3)).sendError(HttpServletResponse.SC_NOT_FOUND); + } + + @Test + void shouldRejectNonJSONContentType() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("POST"); + when(request.getHeader("Content-Type")).thenReturn("text/plain"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doPost(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 415 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRejectInvalidAcceptHeader() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/html"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 406 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRequireSessionIdForSSE() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(null); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleSessionCleanup() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + String sessionId = "test-session-5"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + // Create a session first + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doDelete(request, response); + + verify(response).setStatus(HttpServletResponse.SC_OK); + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleDeleteNonExistentSession() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("non-existent-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doDelete(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 404 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleMultipleSessions() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + // Create separate mock sessions for each ID + McpServerSession mockSession1 = mock(McpServerSession.class); + McpServerSession mockSession2 = mock(McpServerSession.class); + when(mockSession1.getId()).thenReturn(sessionId1); + when(mockSession2.getId()).thenReturn(sessionId2); + when(mockSession1.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession2.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession1.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + when(mockSession2.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + // Configure factory to return different sessions for different IDs + when(sessionFactory.create(sessionId1)).thenReturn(mockSession1); + when(sessionFactory.create(sessionId2)).thenReturn(mockSession2); + + McpServerSession session1 = transportProvider.getOrCreateSession(sessionId1, true); + McpServerSession session2 = transportProvider.getOrCreateSession(sessionId2, true); + + assertThat(session1).isNotNull(); + assertThat(session2).isNotNull(); + assertThat(session1).isNotSameAs(session2); + + // Verify both sessions are created with different IDs + verify(sessionFactory, times(2)).create(anyString()); + } + + @Test + void shouldReuseExistingSession() { + String sessionId = "test-session-6"; + + McpServerSession session1 = transportProvider.getOrCreateSession(sessionId, true); + McpServerSession session2 = transportProvider.getOrCreateSession(sessionId, false); + + assertThat(session1).isSameAs(session2); + verify(sessionFactory, times(1)).create(sessionId); + } + + @Test + void shouldHandleAsyncTimeout() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + + transportProvider.getOrCreateSession("test-session", true); + transportProvider.doGet(request, response); + + verify(asyncContext).setTimeout(0L); // Updated to match actual implementation + } + + @Test + void shouldBuildWithCustomConfiguration() { + ObjectMapper customMapper = new ObjectMapper(); + String customEndpoint = "/custom-mcp"; + + StreamableHttpServerTransportProvider provider = StreamableHttpServerTransportProvider.builder() + .withObjectMapper(customMapper) + .withMcpEndpoint(customEndpoint) + .withSessionIdProvider(() -> "custom-session-id") + .build(); + + assertThat(provider).isNotNull(); + } + + @Test + void shouldHandleBuilderValidation() { + try { + StreamableHttpServerTransportProvider.builder().withObjectMapper(null).build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("ObjectMapper must not be null"); + } + + try { + StreamableHttpServerTransportProvider.builder().withMcpEndpoint("").build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("MCP endpoint must not be empty"); + } + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e0..4c7bfbc0c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -9,6 +9,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -144,7 +146,7 @@ void testRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); + McpId.of("test-id"), echoMessage); transport.simulateIncomingMessage(request); // Verify response @@ -179,7 +181,7 @@ void testNotificationHandling() { void testUnknownMethodHandling() { // Simulate incoming request for unknown method McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); + McpId.of("test-id"), null); transport.simulateIncomingMessage(request); // Verify error response diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ea063e4e3..782a4cf4c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; import net.javacrumbs.jsonunit.core.Option; @@ -240,8 +241,8 @@ void testJSONRPCRequest() throws Exception { Map params = new HashMap<>(); params.put("key", "value"); - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, - params); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", + McpId.of(1), params); String value = mapper.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) @@ -272,7 +273,8 @@ void testJSONRPCResponse() throws Exception { Map result = new HashMap<>(); result.put("result_key", "result_value"); - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, McpId.of(1), + result, null); String value = mapper.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) @@ -287,7 +289,8 @@ void testJSONRPCResponseWithError() throws Exception { McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, McpId.of(1), null, + error); String value = mapper.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER)