prompts,
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java
new file mode 100644
index 000000000..5a4331cf5
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+/**
+ * Handler interface for session lifecycle and runtime events in the Streamable HTTP
+ * transport.
+ *
+ *
+ * This interface provides hooks for monitoring and responding to various session-related
+ * events that occur during the operation of the HTTP-based MCP server transport.
+ * Implementations can use these callbacks to:
+ *
+ * - Log session activities
+ * - Implement custom session management logic
+ * - Handle error conditions
+ * - Perform cleanup operations
+ *
+ *
+ * @author Zachary German
+ */
+public interface SessionHandler {
+
+ /**
+ * Called when a new session is created.
+ * @param sessionId The ID of the newly created session
+ * @param context Additional context information (may be null)
+ */
+ void onSessionCreate(String sessionId, Object context);
+
+ /**
+ * Called when a session is closed.
+ * @param sessionId The ID of the closed session
+ */
+ void onSessionClose(String sessionId);
+
+ /**
+ * Called when a session is not found for a given session ID.
+ * @param sessionId The session ID that was not found
+ * @param request The HTTP request that referenced the missing session
+ * @param response The HTTP response that will be sent to the client
+ */
+ void onSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response);
+
+ /**
+ * Called when an error occurs while sending a notification to a session.
+ * @param sessionId The ID of the session where the error occurred
+ * @param error The error that occurred
+ */
+ void onSendNotificationError(String sessionId, Throwable error);
+
+}
\ No newline at end of file
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java
new file mode 100644
index 000000000..d32038b78
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java
@@ -0,0 +1,893 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.spec.McpServerSession;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
+import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse;
+import io.modelcontextprotocol.spec.McpServerSession;
+import io.modelcontextprotocol.spec.McpServerTransport;
+import io.modelcontextprotocol.spec.McpServerTransportProvider;
+import io.modelcontextprotocol.spec.SseEvent;
+import io.modelcontextprotocol.util.Assert;
+import jakarta.servlet.AsyncContext;
+import jakarta.servlet.ReadListener;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletInputStream;
+import jakarta.servlet.annotation.WebServlet;
+import jakarta.servlet.http.HttpServlet;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import java.nio.charset.StandardCharsets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.util.context.Context;
+
+import static java.util.Objects.requireNonNullElse;
+
+/**
+ * MCP Streamable HTTP transport provider that uses a single session class to manage all
+ * streams and transports.
+ *
+ *
+ * Key improvements over the original implementation:
+ *
+ * - Manages server-client sessions, including transport registration.
+ *
- Handles HTTP requests, providing direct-HTTP or SSE-streamed responses.
+ *
- Provides callbacks for session lifecycle and errors.
+ *
- Enforces allowed 'Origin' header values if configured.
+ *
- Provides a default constructor and default values for all constructor parameters.
+ *
+ *
+ * @author Zachary German
+ */
+@WebServlet(asyncSupported = true)
+public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class);
+
+ public static final String UTF_8 = "UTF-8";
+
+ public static final String APPLICATION_JSON = "application/json";
+
+ public static final String TEXT_EVENT_STREAM = "text/event-stream";
+
+ public static final String SESSION_ID_HEADER = "Mcp-Session-Id";
+
+ public static final String LAST_EVENT_ID_HEADER = "Last-Event-Id";
+
+ public static final String MESSAGE_EVENT_TYPE = "message";
+
+ public static final String ACCEPT_HEADER = "Accept";
+
+ public static final String ORIGIN_HEADER = "Origin";
+
+ public static final String ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin";
+
+ public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*";
+
+ public static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version";
+
+ public static final String CACHE_CONTROL_HEADER = "Cache-Control";
+
+ public static final String CONNECTION_HEADER = "Connection";
+
+ public static final String CACHE_CONTROL_NO_CACHE = "no-cache";
+
+ public static final String CONNECTION_KEEP_ALIVE = "keep-alive";
+
+ public static final String MCP_SESSION_ID = "MCP-Session-ID";
+
+ public static final String DEFAULT_MCP_ENDPOINT = "/mcp";
+
+ /** com.fasterxml.jackson.databind.ObjectMapper */
+ private static final ObjectMapper DEFAULT_OBJECT_MAPPER = new ObjectMapper();
+
+ /** UUID.randomUUID().toString() */
+ private static final Supplier DEFAULT_SESSION_ID_PROVIDER = () -> UUID.randomUUID().toString();
+
+ /** JSON object mapper for serialization/deserialization */
+ private final ObjectMapper objectMapper;
+
+ /** The endpoint path for handling MCP requests */
+ private final String mcpEndpoint;
+
+ /** Supplier for generating unique session IDs */
+ private final Supplier sessionIdProvider;
+
+ /** Sessions map, keyed by Session ID */
+ private static final Map sessions = new ConcurrentHashMap<>();
+
+ /** Flag indicating if the transport is in the process of shutting down */
+ private final AtomicBoolean isClosing = new AtomicBoolean(false);
+
+ /** Optional allowed 'Origin' header value list. Not enforced if empty. */
+ private final List allowedOrigins = new ArrayList<>();
+
+ /** Callback interface for session lifecycle and errors */
+ private SessionHandler sessionHandler;
+
+ /** Factory for McpServerSession takes session IDs */
+ private McpServerSession.StreamableHttpSessionFactory streamableHttpSessionFactory;
+
+ /**
+ *
+ * - Manages server-client sessions, including transport registration.
+ *
- Handles HTTP requests and HTTP/SSE responses and streams.
+ *
+ * @param objectMapper ObjectMapper - Default:
+ * com.fasterxml.jackson.databind.ObjectMapper
+ * @param mcpEndpoint String - Default: '/mcp'
+ * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString()
+ */
+ public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
+ Supplier sessionIdProvider) {
+ this.objectMapper = requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER);
+ this.mcpEndpoint = requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT);
+ this.sessionIdProvider = requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER);
+ }
+
+ /**
+ *
+ * - Manages server-client sessions, including transport registration.
+ *
- Handles HTTP requests and HTTP/SSE responses and streams.
+ *
+ * @param objectMapper ObjectMapper - Default:
+ * com.fasterxml.jackson.databind.ObjectMapper
+ * @param mcpEndpoint String - Default: '/mcp'
+ * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString()
+ */
+ public StreamableHttpServerTransportProvider() {
+ this(null, null, null);
+ }
+
+ @Override
+ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
+ // Required but not used for this implementation
+ }
+
+ public void setStreamableHttpSessionFactory(McpServerSession.StreamableHttpSessionFactory sessionFactory) {
+ this.streamableHttpSessionFactory = sessionFactory;
+ }
+
+ public void setSessionHandler(SessionHandler sessionHandler) {
+ this.sessionHandler = sessionHandler;
+ }
+
+ public void setAllowedOrigins(List allowedOrigins) {
+ this.allowedOrigins.clear();
+ this.allowedOrigins.addAll(allowedOrigins);
+ }
+
+ @Override
+ public Mono notifyClients(String method, Object params) {
+ if (sessions.isEmpty()) {
+ logger.debug("No active sessions to broadcast message to");
+ return Mono.empty();
+ }
+
+ logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
+
+ return Flux.fromIterable(sessions.values())
+ .filter(session -> session.getState() == session.STATE_INITIALIZED)
+ .flatMap(session -> session.sendNotification(method, params).doOnError(e -> {
+ logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage());
+ if (sessionHandler != null) {
+ sessionHandler.onSendNotificationError(session.getId(), e);
+ }
+ }).onErrorComplete())
+ .then();
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.defer(() -> {
+ isClosing.set(true);
+ logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
+ return Flux.fromIterable(sessions.values())
+ .flatMap(session -> session.closeGracefully()
+ .doOnError(e -> logger.error("Error closing session {}: {}", session.getId(), e.getMessage()))
+ .onErrorComplete())
+ .then();
+ });
+ }
+
+ @Override
+ protected void doGet(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ logger.info("GET request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request));
+
+ if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response)
+ || !validateNotClosing(response)) {
+ return;
+ }
+
+ String acceptHeader = request.getHeader(ACCEPT_HEADER);
+ if (acceptHeader == null || !acceptHeader.contains(TEXT_EVENT_STREAM)) {
+ logger.debug("Accept header missing or does not include {}", TEXT_EVENT_STREAM);
+ sendErrorResponse(response, "Accept header must include text/event-stream");
+ return;
+ }
+
+ String sessionId = request.getHeader(SESSION_ID_HEADER);
+ if (sessionId == null) {
+ sendErrorResponse(response, "Session ID missing in request header");
+ return;
+ }
+ else {
+ response.setHeader(SESSION_ID_HEADER, sessionId);
+ }
+
+ McpServerSession session = sessions.get(sessionId);
+ if (session == null) {
+ handleSessionNotFound(sessionId, request, response);
+ return;
+ }
+
+ // Delayed until version negotiation is implemented.
+ /*
+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
+ * sendErrorResponse(response, "Protocol version missing in request header"); }
+ */
+
+ AsyncContext asyncContext = request.startAsync();
+ asyncContext.setTimeout(0);
+
+ String lastEventId = request.getHeader(LAST_EVENT_ID_HEADER);
+
+ if (lastEventId == null) { // Just opening a listening stream
+ SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId,
+ session.LISTENING_TRANSPORT, sessionId);
+ logger.debug("Registered SSE transport {} for session {}", session.LISTENING_TRANSPORT, sessionId);
+ }
+ else { // Asking for a stream to replay events from a previous request
+ SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId,
+ request.getRequestId(), sessionId);
+ logger.debug("Registered SSE transport {} for session {}", request.getRequestId(), sessionId);
+ }
+ }
+
+ @Override
+ protected void doPost(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ logger.info("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request));
+
+ if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response)
+ || !validateNotClosing(response)) {
+ return;
+ }
+
+ String acceptHeader = request.getHeader(ACCEPT_HEADER);
+ if (acceptHeader == null
+ || (!acceptHeader.contains(APPLICATION_JSON) || !acceptHeader.contains(TEXT_EVENT_STREAM))) {
+ logger.debug("Accept header validation failed. Header: {}", acceptHeader);
+ sendErrorResponse(response, "Accept header must include both application/json and text/event-stream");
+ return;
+ }
+
+ AsyncContext asyncContext = request.startAsync();
+ asyncContext.setTimeout(0);
+
+ StringBuilder body = new StringBuilder();
+ ServletInputStream inputStream = request.getInputStream();
+
+ inputStream.setReadListener(new ReadListener() {
+ @Override
+ public void onDataAvailable() throws IOException {
+ int len;
+ byte[] buffer = new byte[1024];
+ while (inputStream.isReady() && (len = inputStream.read(buffer)) != -1) {
+ body.append(new String(buffer, 0, len, StandardCharsets.UTF_8));
+ }
+ }
+
+ @Override
+ public void onAllDataRead() throws IOException {
+ try {
+ logger.debug("Parsing JSON-RPC message: {}", body.toString());
+ JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());
+
+ boolean isInitializeRequest = false;
+ String sessionId = request.getHeader(SESSION_ID_HEADER);
+
+ if (message instanceof McpSchema.JSONRPCRequest req
+ && McpSchema.METHOD_INITIALIZE.equals(req.method())) {
+ isInitializeRequest = true;
+ logger.debug("Detected initialize request");
+ if (sessionId == null) {
+ sessionId = sessionIdProvider.get();
+ logger.debug("Created new session ID for initialize request: {}", sessionId);
+ }
+ }
+
+ if (!isInitializeRequest && sessionId == null) {
+ sendErrorResponse(response, "Session ID missing in request header");
+ asyncContext.complete();
+ return;
+ }
+
+ McpServerSession session = getOrCreateSession(sessionId, isInitializeRequest);
+ if (session == null) {
+ logger.error("Failed to create session for sessionId: {}", sessionId);
+ handleSessionNotFound(sessionId, request, response);
+ asyncContext.complete();
+ return;
+ }
+
+ // Delayed until version negotiation is implemented.
+ /*
+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
+ * sendErrorResponse(response,
+ * "Protocol version missing in request header"); }
+ */
+
+ logger.debug("Using session: {}", sessionId);
+
+ response.setHeader(SESSION_ID_HEADER, sessionId);
+
+ final String transportId;
+ if (message instanceof JSONRPCRequest req) {
+ transportId = req.id().toString();
+ }
+ else if (message instanceof JSONRPCResponse resp) {
+ transportId = resp.id().toString();
+ }
+ else {
+ transportId = null;
+ }
+
+ // Only set content type for requests
+ if (message instanceof McpSchema.JSONRPCRequest) {
+ response.setContentType(APPLICATION_JSON);
+ }
+
+ if (transportId != null) { // Not needed for notifications (null
+ // transportId)
+ HttpTransport httpTransport = new HttpTransport(objectMapper, response, asyncContext);
+ session.registerTransport(transportId, httpTransport);
+ }
+
+ // For notifications, we need to handle the HTTP response manually
+ // since no JSON response is sent
+ if (message instanceof McpSchema.JSONRPCNotification) {
+ session.handle(message).doOnSuccess(v -> {
+ logger.debug("[NOTIFICATION] Sending empty HTTP response for notification");
+ try {
+ if (!response.isCommitted()) {
+ response.setStatus(HttpServletResponse.SC_OK);
+ response.setCharacterEncoding("UTF-8");
+ }
+ asyncContext.complete();
+ }
+ catch (Exception e) {
+ logger.error("Failed to send notification response: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ }).doOnError(e -> {
+ logger.error("Error in message handling: {}", e.getMessage(), e);
+ asyncContext.complete();
+ }).contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe();
+ }
+ else {
+ // For requests, let the transport handle the response
+ session.handle(message)
+ .doOnSuccess(v -> logger.info("Message handling completed successfully for transport: {}",
+ transportId))
+ .doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e))
+ .doFinally(signalType -> {
+ logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType);
+ session.unregisterTransport(transportId);
+ })
+ .contextWrite(Context.of(MCP_SESSION_ID, sessionId))
+ .subscribe(null, error -> {
+ logger.error("Error in message handling chain: {}", error.getMessage(), error);
+ asyncContext.complete();
+ });
+ }
+
+ }
+ catch (Exception e) {
+ logger.error("Error processing message: {}", e.getMessage());
+ sendErrorResponse(response, "Invalid JSON-RPC message: " + e.getMessage());
+ asyncContext.complete();
+ }
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ logger.error("Error reading request body: {}", t.getMessage());
+ try {
+ sendErrorResponse(response, "Error reading request: " + t.getMessage());
+ }
+ catch (IOException e) {
+ logger.error("Failed to write error response", e);
+ }
+ asyncContext.complete();
+ }
+ });
+ }
+
+ @Override
+ protected void doDelete(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ if (!requestURI.endsWith(mcpEndpoint)) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ String sessionId = request.getHeader(SESSION_ID_HEADER);
+ if (sessionId == null) {
+ sendErrorResponse(response, "Session ID missing in request header");
+ return;
+ }
+ else {
+ response.setHeader(SESSION_ID_HEADER, sessionId);
+ }
+
+ McpServerSession session = sessions.remove(sessionId);
+ if (session == null) {
+ handleSessionNotFound(sessionId, request, response);
+ return;
+ }
+
+ session.closeGracefully().contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe();
+ logger.debug("Session closed via DELETE request: {}", sessionId);
+ if (sessionHandler != null) {
+ sessionHandler.onSessionClose(sessionId);
+ }
+
+ response.setStatus(HttpServletResponse.SC_OK);
+ }
+
+ private boolean validateOrigin(HttpServletRequest request, HttpServletResponse response) throws IOException {
+ if (!allowedOrigins.isEmpty()) {
+ String origin = request.getHeader(ORIGIN_HEADER);
+ if (!allowedOrigins.contains(origin)) {
+ logger.debug("Origin header does not match allowed origins: '{}'", origin);
+ response.sendError(HttpServletResponse.SC_FORBIDDEN);
+ return false;
+ }
+ else {
+ response.setHeader(ALLOW_ORIGIN_HEADER, origin);
+ }
+ }
+ else {
+ response.setHeader(ALLOW_ORIGIN_HEADER, ALLOW_ORIGIN_DEFAULT_VALUE);
+ }
+ return true;
+ }
+
+ private boolean validateEndpoint(String requestURI, HttpServletResponse response) throws IOException {
+ if (!requestURI.endsWith(mcpEndpoint)) {
+ logger.debug("URI does not match MCP endpoint: '{}'", mcpEndpoint);
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return false;
+ }
+ return true;
+ }
+
+ private boolean validateNotClosing(HttpServletResponse response) throws IOException {
+ if (isClosing.get()) {
+ logger.debug("Server is shutting down, rejecting request");
+ response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
+ return false;
+ }
+ return true;
+ }
+
+ protected McpServerSession getOrCreateSession(String sessionId, boolean createIfMissing) {
+ McpServerSession session = sessions.get(sessionId);
+ logger.debug("Looking for session: {}, found: {}", sessionId, session != null);
+ if (session == null && createIfMissing) {
+ logger.debug("Creating new session: {}", sessionId);
+ session = streamableHttpSessionFactory.create(sessionId);
+ session.setIsStreamableHttp(true); // TODO: Remove this if we split SHTTP server
+ // session to its own class.
+ sessions.put(sessionId, session);
+ logger.debug("Created new session: {}", sessionId);
+ if (sessionHandler != null) {
+ sessionHandler.onSessionCreate(sessionId, null);
+ }
+ }
+ return session;
+ }
+
+ private void handleSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+ sendErrorResponse(response, "Session not found: " + sessionId);
+ if (sessionHandler != null) {
+ sessionHandler.onSessionNotFound(sessionId, request, response);
+ }
+ }
+
+ private void sendErrorResponse(HttpServletResponse response, String message) throws IOException {
+ response.setContentType(APPLICATION_JSON);
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
+ response.getWriter().write(createErrorJson(message));
+ }
+
+ private String createErrorJson(String message) {
+ try {
+ return objectMapper.writeValueAsString(new McpError(message));
+ }
+ catch (IOException e) {
+ logger.error("Failed to serialize error message", e);
+ return "{\"error\":\"" + message + "\"}";
+ }
+ }
+
+ @Override
+ public void destroy() {
+ closeGracefully().block();
+ super.destroy();
+ }
+
+ private Map extractHeaders(HttpServletRequest request) {
+ Map headers = new HashMap<>();
+ Enumeration headerNames = request.getHeaderNames();
+ while (headerNames.hasMoreElements()) {
+ String name = headerNames.nextElement();
+ headers.put(name, request.getHeader(name));
+ }
+ return headers;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ private ObjectMapper objectMapper = DEFAULT_OBJECT_MAPPER;
+
+ private String mcpEndpoint = DEFAULT_MCP_ENDPOINT;
+
+ private Supplier sessionIdProvider = DEFAULT_SESSION_ID_PROVIDER;
+
+ public Builder withObjectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ public Builder withMcpEndpoint(String mcpEndpoint) {
+ Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty");
+ this.mcpEndpoint = mcpEndpoint;
+ return this;
+ }
+
+ public Builder withSessionIdProvider(Supplier sessionIdProvider) {
+ Assert.notNull(sessionIdProvider, "SessionIdProvider must not be null");
+ this.sessionIdProvider = sessionIdProvider;
+ return this;
+ }
+
+ public StreamableHttpServerTransportProvider build() {
+ return new StreamableHttpServerTransportProvider(objectMapper, mcpEndpoint, sessionIdProvider);
+ }
+
+ }
+
+ private enum ResponseType {
+
+ IMMEDIATE, STREAM
+
+ }
+
+ /**
+ * SSE transport implementation.
+ */
+ public static class SseTransport implements McpServerTransport {
+
+ private static final Logger logger = LoggerFactory.getLogger(SseTransport.class);
+
+ private final ObjectMapper objectMapper;
+
+ private final HttpServletResponse response;
+
+ private final AsyncContext asyncContext;
+
+ private final Sinks.Many eventSink = Sinks.many().unicast().onBackpressureBuffer();
+
+ private final Map eventHistory = new ConcurrentHashMap<>();
+
+ private final String id;
+
+ private final String sessionId;
+
+ public SseTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext,
+ String lastEventId, String transportId, String sessionId) {
+ this.objectMapper = objectMapper;
+ this.response = response;
+ this.asyncContext = asyncContext;
+ this.id = transportId;
+ this.sessionId = sessionId;
+
+ response.setContentType(TEXT_EVENT_STREAM);
+ response.setCharacterEncoding(UTF_8);
+ response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE);
+ response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE);
+
+ logger.debug("Establishing SSE stream with ID: {} for session: {}", transportId, sessionId);
+ setupSseStream(lastEventId);
+
+ logger.debug("Registering SSE transport with ID: {} for session: {}", transportId, sessionId);
+ sessions.get(sessionId).registerTransport(transportId, this);
+ }
+
+ private void setupSseStream(String lastEventId) {
+ try {
+ PrintWriter writer = response.getWriter();
+
+ eventSink.asFlux().doOnNext(event -> {
+ try {
+ if (event.id() != null) {
+ writer.write("id: " + event.id() + "\n");
+ }
+ if (event.event() != null) {
+ writer.write("event: " + event.event() + "\n");
+ }
+ writer.write("data: " + event.data() + "\n\n");
+ writer.flush();
+
+ if (writer.checkError()) {
+ throw new IOException("Client disconnected");
+ }
+ }
+ catch (IOException e) {
+ logger.debug("Error writing to SSE stream: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ }).doOnComplete(() -> {
+ try {
+ writer.close();
+ }
+ finally {
+ asyncContext.complete();
+ }
+ }).doOnError(e -> {
+ logger.error("Error in SSE stream: {}", e.getMessage());
+ asyncContext.complete();
+ }).contextWrite(Context.of(MCP_SESSION_ID, response.getHeader(SESSION_ID_HEADER))).subscribe();
+
+ // Replay events if requested
+ if (lastEventId != null) {
+ replayEventsAfter(lastEventId);
+ }
+
+ }
+ catch (IOException e) {
+ logger.error("Failed to setup SSE stream: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ }
+
+ private void replayEventsAfter(String lastEventId) {
+ try {
+ McpServerSession session = sessions.get(sessionId);
+ String transportIdOfLastEventId = session.getTransportIdForEvent(lastEventId);
+ Map transportEventHistory = session
+ .getTransportEventHistory(transportIdOfLastEventId);
+ List eventIds = transportEventHistory.keySet()
+ .stream()
+ .map(Long::parseLong)
+ .filter(key -> key > Long.parseLong(lastEventId))
+ .sorted()
+ .map(String::valueOf)
+ .collect(Collectors.toList());
+ for (String eventId : eventIds) {
+ SseEvent event = transportEventHistory.get(eventId);
+ if (event != null) {
+ eventSink.tryEmitNext(event);
+ }
+ }
+ logger.debug("Completing SSE stream after replaying events");
+ eventSink.tryEmitComplete();
+ }
+ catch (NumberFormatException e) {
+ logger.warn("Invalid last event ID: {}", lastEventId);
+ }
+ }
+
+ @Override
+ public Mono sendMessage(JSONRPCMessage message) {
+ try {
+ String jsonText = objectMapper.writeValueAsString(message);
+ String eventId = sessions.get(sessionId).incrementAndGetEventId(id);
+ SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText);
+
+ eventHistory.put(eventId, event);
+ logger.debug("Sending SSE event {}: {}", eventId, jsonText);
+ eventSink.tryEmitNext(event);
+
+ if (message instanceof McpSchema.JSONRPCResponse) {
+ logger.debug("Completing SSE stream after sending response");
+ eventSink.tryEmitComplete();
+ McpServerSession session = sessions.get(sessionId);
+ if (session != null) {
+ session.setTransportEventHistory(id, eventHistory);
+ }
+ }
+
+ return Mono.empty();
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message: {}", e.getMessage());
+ return Mono.error(e);
+ }
+ }
+
+ /**
+ * Sends a stream of messages for Flux responses.
+ */
+ public Mono sendMessageStream(Flux messageStream) {
+ return messageStream.doOnNext(message -> {
+ try {
+ String jsonText = objectMapper.writeValueAsString(message);
+ String eventId = sessions.get(sessionId).incrementAndGetEventId(id);
+ SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText);
+
+ eventHistory.put(eventId, event);
+ logger.debug("Sending SSE stream event {}: {}", eventId, jsonText);
+ eventSink.tryEmitNext(event);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send stream message: {}", e.getMessage());
+ eventSink.tryEmitError(e);
+ }
+ }).doOnComplete(() -> {
+ logger.debug("Completing SSE stream after sending all stream messages");
+ eventSink.tryEmitComplete();
+ McpServerSession session = sessions.get(sessionId);
+ if (session != null) {
+ session.setTransportEventHistory(id, eventHistory);
+ }
+ }).then();
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ eventSink.tryEmitComplete();
+ McpServerSession session = sessions.get(sessionId);
+ if (session != null) {
+ session.setTransportEventHistory(id, eventHistory);
+ }
+ logger.debug("SSE transport closed gracefully");
+ });
+ }
+
+ }
+
+ /**
+ * HTTP transport implementation for immediate responses.
+ */
+ public static class HttpTransport implements McpServerTransport {
+
+ private static final Logger logger = LoggerFactory.getLogger(HttpTransport.class);
+
+ private final ObjectMapper objectMapper;
+
+ private final HttpServletResponse response;
+
+ private final AsyncContext asyncContext;
+
+ public HttpTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext) {
+ this.objectMapper = objectMapper;
+ this.response = response;
+ this.asyncContext = asyncContext;
+ }
+
+ public ObjectMapper getObjectMapper() {
+ return this.objectMapper;
+ }
+
+ public HttpServletResponse getResponse() {
+ return this.response;
+ }
+
+ public AsyncContext getAsyncContext() {
+ return this.asyncContext;
+ }
+
+ @Override
+ public Mono sendMessage(JSONRPCMessage message) {
+ return Mono.fromRunnable(() -> {
+ try {
+ if (response.isCommitted()) {
+ logger.warn("Response already committed, cannot send message");
+ return;
+ }
+
+ response.setCharacterEncoding("UTF-8");
+ response.setStatus(HttpServletResponse.SC_OK);
+
+ // For notifications, don't write any content (empty response)
+ if (message instanceof McpSchema.JSONRPCNotification) {
+ logger.debug("Sending empty 200 response for notification");
+ // Just complete the response with no content
+ }
+ else {
+ // For requests/responses, write JSON content
+ String jsonText = objectMapper.writeValueAsString(message);
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonText);
+ writer.flush();
+ logger.debug("Successfully sent immediate response: {}", jsonText);
+ }
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message: {}", e.getMessage(), e);
+ try {
+ if (!response.isCommitted()) {
+ response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
+ }
+ }
+ catch (Exception ignored) {
+ }
+ }
+ finally {
+ asyncContext.complete();
+ }
+ });
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ try {
+ asyncContext.complete();
+ }
+ catch (Exception e) {
+ logger.debug("Error completing async context: {}", e.getMessage());
+ }
+ logger.debug("HTTP transport closed gracefully");
+ });
+ }
+
+ }
+
+}
\ No newline at end of file
diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
index cc7d2abf8..454f1fc4b 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
@@ -5,6 +5,8 @@
package io.modelcontextprotocol.spec;
import com.fasterxml.jackson.core.type.TypeReference;
+
+import io.modelcontextprotocol.spec.McpSchema.McpId;
import io.modelcontextprotocol.util.Assert;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
@@ -47,7 +49,7 @@ public class McpClientSession implements McpSession {
private final McpClientTransport transport;
/** Map of pending responses keyed by request ID */
- private final ConcurrentHashMap