From 2c8894ef8bd2b59c6ee19a75d694544c39d0662b Mon Sep 17 00:00:00 2001 From: Zach Tang Date: Fri, 4 Jul 2025 13:40:12 -0700 Subject: [PATCH 1/6] feat: add streamable http client --- .../client/StreamableHttpClientTransport.kt | 179 ++++++++++++++++++ .../StreamableHttpMcpKtorClientExtensions.kt | 41 ++++ 2 files changed, 220 insertions(+) create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt new file mode 100644 index 0000000..d7f42ca --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -0,0 +1,179 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.ClientSSESession +import io.ktor.client.plugins.sse.sseSession +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.append +import io.ktor.http.contentType +import io.ktor.http.isSuccess +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi + +private val logger = KotlinLogging.logger {} + +/** + * Client transport for Streamable HTTP: this will send messages via HTTP POST requests + * and optionally receive streaming responses via SSE. + * + * This implements the Streamable HTTP transport as specified in MCP 2024-11-05. + */ +@OptIn(ExperimentalAtomicApi::class) +public class StreamableHttpClientTransport( + private val client: HttpClient, + private val url: String, + private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, +) : AbstractTransport() { + + private val initialized: AtomicBoolean = AtomicBoolean(false) + private var sseSession: ClientSSESession? = null + private val scope by lazy { CoroutineScope(SupervisorJob()) } + private var sseJob: Job? = null + private var sessionId: String? = null + + override suspend fun start() { + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { + error("StreamableHttpClientTransport already started!") + } + logger.debug { "Client transport starting..." } + startSseSession() + } + + private suspend fun startSseSession() { + logger.debug { "Client attempting to start SSE session at url: $url" } + try { + sseSession = client.sseSession( + urlString = url, + block = requestBuilder, + ) + logger.debug { "Client SSE session started successfully." } + + sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { + sseSession?.incoming?.collect { event -> + logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}" } + when (event.event) { + "error" -> { + val e = IllegalStateException("SSE error: ${event.data}") + logger.error(e) { "SSE stream reported an error event." } + _onError(e) + } + + else -> { + // All non-error events are treated as JSON-RPC messages + try { + val eventData = event.data + if (!eventData.isNullOrEmpty()) { + val message = McpJson.decodeFromString(eventData) + _onMessage(message) + } + } catch (e: Exception) { + logger.error(e) { "Error processing SSE message" } + _onError(e) + } + } + } + } + } + } catch (e: Exception) { + // SSE session is optional, don't fail if it can't be established + // The server might not support GET requests for SSE + logger.warn(e) { "Client failed to start SSE session. This may be expected if the server does not support GET." } + _onError(e) + } + } + + override suspend fun send(message: JSONRPCMessage) { + logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } + try { + val response = client.post(url) { + requestBuilder() + contentType(ContentType.Application.Json) + headers.append(HttpHeaders.Accept, "${ContentType.Application.Json}, ${ContentType.Text.EventStream}") + + // Add session ID if we have one + sessionId?.let { + headers.append("Mcp-Session-Id", it) + } + + setBody(McpJson.encodeToString(message)) + } + logger.debug { "Client received POST response: ${response.status}" } + + if (!response.status.isSuccess()) { + val text = response.bodyAsText() + val error = Exception("HTTP ${response.status}: $text") + logger.error(error) { "Client POST request failed." } + _onError(error) + throw error + } + + // Extract session ID from response headers if present + response.headers["Mcp-Session-Id"]?.let { + sessionId = it + } + + // Handle response based on content type + when (response.contentType()?.contentType) { + ContentType.Application.Json.contentType -> { + // Single JSON response + val responseBody = response.bodyAsText() + logger.trace { "Client processing JSON response: $responseBody" } + if (responseBody.isNotEmpty()) { + try { + val responseMessage = McpJson.decodeFromString(responseBody) + _onMessage(responseMessage) + } catch (e: Exception) { + logger.error(e) { "Error processing JSON response" } + _onError(e) + } + } + } + + ContentType.Text.EventStream.contentType -> { + logger.trace { "Client received SSE stream in POST response. Messages will be handled by the main SSE session." } + } + + else -> { + logger.trace { "Client received response with unexpected or no content type: ${response.contentType()}" } + } + } + } catch (e: Exception) { + logger.error(e) { "Client send failed." } + _onError(e) + throw e + } + } + + override suspend fun close() { + if (!initialized.load()) { + return // Already closed or never started + } + logger.debug { "Client transport closing." } + + try { + sseSession?.cancel() + sseJob?.cancelAndJoin() + scope.cancel() + } catch (e: Exception) { + // Ignore errors during cleanup + } finally { + _onClose() + } + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt new file mode 100644 index 0000000..718c882 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.ktor.client.HttpClient +import io.ktor.client.request.HttpRequestBuilder +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION +import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME + +/** + * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A [StreamableHttpClientTransport] configured for MCP communication. + */ +public fun HttpClient.mcpStreamableHttpTransport( + url: String, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): StreamableHttpClientTransport = StreamableHttpClientTransport(this, url, requestBuilder) + +/** + * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A connected [Client] ready for MCP communication. + */ +public suspend fun HttpClient.mcpStreamableHttp( + url: String, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): Client { + val transport = mcpStreamableHttpTransport(url, requestBuilder) + val client = Client( + Implementation( + name = IMPLEMENTATION_NAME, + version = LIB_VERSION + ) + ) + client.connect(transport) + return client +} \ No newline at end of file From 0d0403e88c8c59910e7b5d1f0919d97859f78d86 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 14 Jul 2025 01:54:22 +0200 Subject: [PATCH 2/6] update StreamableHttpClientTransport --- api/kotlin-sdk.api | 30 ++ .../client/StreamableHttpClientTransport.kt | 362 +++++++++++++----- .../StreamableHttpMcpKtorClientExtensions.kt | 17 +- .../modelcontextprotocol/kotlin/sdk/types.kt | 10 +- 4 files changed, 309 insertions(+), 110 deletions(-) diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index e04795b..8d356d5 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -1070,6 +1070,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/JSONRPCResponse : io/model public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse$Companion; public fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;)V public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse;Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse; public final fun getError ()Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; public final fun getId ()Lio/modelcontextprotocol/kotlin/sdk/RequestId; public final fun getJsonrpc ()Ljava/lang/String; @@ -2947,6 +2949,34 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTranspor public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getProtocolVersion ()Ljava/lang/String; + public final fun getSessionId ()Ljava/lang/String; + public final fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun send$default (Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun setProtocolVersion (Ljava/lang/String;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun terminateSession (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError : java/lang/Exception { + public fun ()V + public fun (Ljava/lang/Integer;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/Integer;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCode ()Ljava/lang/Integer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; +} + public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index d7f42ca..8baf863 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -2,22 +2,35 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient +import io.ktor.client.plugins.ClientRequestException import io.ktor.client.plugins.sse.ClientSSESession import io.ktor.client.plugins.sse.sseSession import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.delete +import io.ktor.client.request.headers import io.ktor.client.request.post import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel import io.ktor.client.statement.bodyAsText import io.ktor.http.ContentType import io.ktor.http.HttpHeaders -import io.ktor.http.append +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode import io.ktor.http.contentType import io.ktor.http.isSuccess +import io.ktor.utils.io.readUTF8Line import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.RequestId import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel @@ -25,155 +38,302 @@ import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.launch import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.time.Duration private val logger = KotlinLogging.logger {} +private const val MCP_SESSION_ID_HEADER = "mcp-session-id" +private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" +private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" + +/** + * Error class for Streamable HTTP transport errors. + */ +public class StreamableHttpError( + public val code: Int? = null, + message: String? = null +) : Exception("Streamable HTTP error: $message") + /** - * Client transport for Streamable HTTP: this will send messages via HTTP POST requests - * and optionally receive streaming responses via SSE. - * - * This implements the Streamable HTTP transport as specified in MCP 2024-11-05. + * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events + * for receiving messages. */ @OptIn(ExperimentalAtomicApi::class) public class StreamableHttpClientTransport( private val client: HttpClient, private val url: String, + private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractTransport() { + public var sessionId: String? = null + private set + public var protocolVersion: String? = null + private val initialized: AtomicBoolean = AtomicBoolean(false) + private var sseSession: ClientSSESession? = null - private val scope by lazy { CoroutineScope(SupervisorJob()) } private var sseJob: Job? = null - private var sessionId: String? = null + + private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) } + + private var lastEventId: String? = null override suspend fun start() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error("StreamableHttpClientTransport already started!") } logger.debug { "Client transport starting..." } - startSseSession() } - private suspend fun startSseSession() { - logger.debug { "Client attempting to start SSE session at url: $url" } - try { - sseSession = client.sseSession( - urlString = url, - block = requestBuilder, - ) - logger.debug { "Client SSE session started successfully." } - - sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { - sseSession?.incoming?.collect { event -> - logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}" } - when (event.event) { - "error" -> { - val e = IllegalStateException("SSE error: ${event.data}") - logger.error(e) { "SSE stream reported an error event." } - _onError(e) - } - - else -> { - // All non-error events are treated as JSON-RPC messages - try { - val eventData = event.data - if (!eventData.isNullOrEmpty()) { - val message = McpJson.decodeFromString(eventData) - _onMessage(message) - } - } catch (e: Exception) { - logger.error(e) { "Error processing SSE message" } - _onError(e) - } - } - } - } - } - } catch (e: Exception) { - // SSE session is optional, don't fail if it can't be established - // The server might not support GET requests for SSE - logger.warn(e) { "Client failed to start SSE session. This may be expected if the server does not support GET." } - _onError(e) - } + /** + * Sends a single message with optional resumption support + */ + override suspend fun send(message: JSONRPCMessage) { + send(message, null) } - override suspend fun send(message: JSONRPCMessage) { + /** + * Sends one or more messages with optional resumption support. + * This is the main send method that matches the TypeScript implementation. + */ + public suspend fun send( + message: JSONRPCMessage, + resumptionToken: String?, + onResumptionToken: ((String) -> Unit)? = null + ) { logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } - try { - val response = client.post(url) { - requestBuilder() - contentType(ContentType.Application.Json) - headers.append(HttpHeaders.Accept, "${ContentType.Application.Json}, ${ContentType.Text.EventStream}") - // Add session ID if we have one - sessionId?.let { - headers.append("Mcp-Session-Id", it) - } + // If we have a resumption token, reconnect the SSE stream with it + resumptionToken?.let { token -> + startSseSession( + resumptionToken = token, onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null + ) + return + } - setBody(McpJson.encodeToString(message)) - } - logger.debug { "Client received POST response: ${response.status}" } + val jsonBody = McpJson.encodeToString(message) + val response = client.post(url) { + applyCommonHeaders(this) + headers.append(HttpHeaders.Accept, "${ContentType.Application.Json}, ${ContentType.Text.EventStream}") + contentType(ContentType.Application.Json) + setBody(jsonBody) + requestBuilder() + } - if (!response.status.isSuccess()) { + response.headers[MCP_SESSION_ID_HEADER]?.let { sessionId = it } + + if (message is JSONRPCNotification || message is JSONRPCResponse) { + if (response.status != HttpStatusCode.Accepted) { val text = response.bodyAsText() - val error = Exception("HTTP ${response.status}: $text") - logger.error(error) { "Client POST request failed." } - _onError(error) - throw error + val err = StreamableHttpError(response.status.value, text) + logger.error(err) { "Client POST request failed." } + _onError(err) + throw err } + return + } - // Extract session ID from response headers if present - response.headers["Mcp-Session-Id"]?.let { - sessionId = it + when { + !response.status.isSuccess() -> { + val text = response.bodyAsText() + val err = StreamableHttpError(response.status.value, text) + logger.error(err) { "Client POST request failed." } + _onError(err) + throw err } - // Handle response based on content type - when (response.contentType()?.contentType) { - ContentType.Application.Json.contentType -> { - // Single JSON response - val responseBody = response.bodyAsText() - logger.trace { "Client processing JSON response: $responseBody" } - if (responseBody.isNotEmpty()) { - try { - val responseMessage = McpJson.decodeFromString(responseBody) - _onMessage(responseMessage) - } catch (e: Exception) { - logger.error(e) { "Error processing JSON response" } - _onError(e) - } - } + response.contentType()?.match(ContentType.Application.Json) ?: false -> + response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> + runCatching { McpJson.decodeFromString(json) } + .onSuccess { _onMessage(it) } + .onFailure(_onError) } - ContentType.Text.EventStream.contentType -> { - logger.trace { "Client received SSE stream in POST response. Messages will be handled by the main SSE session." } - } + response.contentType()?.match(ContentType.Text.EventStream) ?: false -> + handleInlineSse( + response, onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null + ) + } - else -> { - logger.trace { "Client received response with unexpected or no content type: ${response.contentType()}" } - } - } - } catch (e: Exception) { - logger.error(e) { "Client send failed." } - _onError(e) - throw e + // If client just sent InitializedNotification, open SSE stream + if (message is JSONRPCNotification && message.method == "notifications/initialized" && sseSession == null) { + startSseSession() } } override suspend fun close() { - if (!initialized.load()) { - return // Already closed or never started - } + if (!initialized.load()) return // Already closed or never started logger.debug { "Client transport closing." } try { + // Try to terminate session if we have one + terminateSession() + sseSession?.cancel() sseJob?.cancelAndJoin() scope.cancel() - } catch (e: Exception) { + } catch (_: Exception) { // Ignore errors during cleanup } finally { + initialized.store(false) _onClose() } } -} \ No newline at end of file + + /** + * Terminates the current session by sending a DELETE request to the server. + */ + public suspend fun terminateSession() { + if (sessionId == null) return + logger.debug { "Terminating session: $sessionId" } + val response = client.delete(url) { + applyCommonHeaders(this) + requestBuilder() + } + + // 405 means server doesn't support explicit session termination + if (!response.status.isSuccess() && response.status != HttpStatusCode.MethodNotAllowed) { + val error = StreamableHttpError( + response.status.value, + "Failed to terminate session: ${response.status.description}" + ) + logger.error(error) { "Failed to terminate session" } + _onError(error) + throw error + } + + sessionId = null + logger.debug { "Session terminated successfully" } + } + + private suspend fun startSseSession( + resumptionToken: String? = null, + replayMessageId: RequestId? = null, + onResumptionToken: ((String) -> Unit)? = null + ) { + sseSession?.cancel() + sseJob?.cancelAndJoin() + + logger.debug { "Client attempting to start SSE session at url: $url" } + try { + sseSession = client.sseSession( + urlString = url, + reconnectionTime = reconnectionTime, + ) { + method = HttpMethod.Get + applyCommonHeaders(this) + headers.append(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } + requestBuilder() + } + logger.debug { "Client SSE session started successfully." } + } catch (e: ClientRequestException) { + if (e.response.status == HttpStatusCode.MethodNotAllowed) { + logger.info { "Server returned 405 for GET/SSE, stream disabled." } + return + } + _onError(e) + throw e + } + + sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { + sseSession?.let { collectSse(it, replayMessageId, onResumptionToken) } + } + } + + private fun applyCommonHeaders(builder: HttpRequestBuilder) { + builder.headers { + sessionId?.let { append(MCP_SESSION_ID_HEADER, it) } + protocolVersion?.let { append(MCP_PROTOCOL_VERSION_HEADER, it) } + } + } + + private suspend fun collectSse( + session: ClientSSESession, + replayMessageId: RequestId?, + onResumptionToken: ((String) -> Unit)? + ) { + try { + session.incoming.collect { event -> + event.id?.let { + lastEventId = it + onResumptionToken?.invoke(it) + } + logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}, id=${event.id}" } + when (event.event) { + null, "message" -> + event.data?.takeIf { it.isNotEmpty() }?.let { json -> + runCatching { McpJson.decodeFromString(json) } + .onSuccess { msg -> + if (replayMessageId != null && msg is JSONRPCResponse) { + _onMessage(msg.copy(id = replayMessageId)) + } else { + _onMessage(msg) + } + } + .onFailure(_onError) + } + + "error" -> _onError(StreamableHttpError(null, event.data)) + } + } + } catch (_: CancellationException) { + // ignore + } catch (t: Throwable) { + _onError(t) + } + } + + private suspend fun handleInlineSse( + response: HttpResponse, + replayMessageId: RequestId?, + onResumptionToken: ((String) -> Unit)? + ) { + logger.trace { "Handling inline SSE from POST response" } + val channel = response.bodyAsChannel() + val reader = channel + + val sb = StringBuilder() + var id: String? = null + var eventName: String? = null + + suspend fun dispatch(data: String) { + id?.let { + lastEventId = it + onResumptionToken?.invoke(it) + } + if (eventName == null || eventName == "message") { + runCatching { McpJson.decodeFromString(data) } + .onSuccess { msg -> + if (replayMessageId != null && msg is JSONRPCResponse) { + _onMessage(msg.copy(id = replayMessageId)) + } else { + _onMessage(msg) + } + } + .onFailure(_onError) + } + // reset + id = null + eventName = null + sb.clear() + } + + while (!reader.isClosedForRead) { + val line = reader.readUTF8Line() ?: break + if (line.isEmpty()) { + dispatch(sb.toString()) + continue + } + when { + line.startsWith("id:") -> line.substring(3).trim() + line.startsWith("event:") -> eventName = line.substring(6).trim() + line.startsWith("data:") -> sb.append(line.substring(5).trim()) + } + } + } +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt index 718c882..c2454e1 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -5,37 +5,38 @@ import io.ktor.client.request.HttpRequestBuilder import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +import kotlin.time.Duration /** * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. * * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. * @param requestBuilder Optional lambda to configure the HTTP request. * @return A [StreamableHttpClientTransport] configured for MCP communication. */ public fun HttpClient.mcpStreamableHttpTransport( url: String, + reconnectionTime: Duration? = null, requestBuilder: HttpRequestBuilder.() -> Unit = {}, -): StreamableHttpClientTransport = StreamableHttpClientTransport(this, url, requestBuilder) +): StreamableHttpClientTransport = + StreamableHttpClientTransport(this, url, reconnectionTime, requestBuilder = requestBuilder) /** * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. * * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. * @param requestBuilder Optional lambda to configure the HTTP request. * @return A connected [Client] ready for MCP communication. */ public suspend fun HttpClient.mcpStreamableHttp( url: String, + reconnectionTime: Duration? = null, requestBuilder: HttpRequestBuilder.() -> Unit = {}, ): Client { - val transport = mcpStreamableHttpTransport(url, requestBuilder) - val client = Client( - Implementation( - name = IMPLEMENTATION_NAME, - version = LIB_VERSION - ) - ) + val transport = mcpStreamableHttpTransport(url, reconnectionTime, requestBuilder) + val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) client.connect(transport) return client } \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 893a8d6..c93519b 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -244,7 +244,15 @@ public class JSONRPCResponse( public val jsonrpc: String = JSONRPC_VERSION, public val result: RequestResult? = null, public val error: JSONRPCError? = null, -) : JSONRPCMessage +) : JSONRPCMessage { + + public fun copy( + id: RequestId = this.id, + jsonrpc: String = this.jsonrpc, + result: RequestResult? = this.result, + error: JSONRPCError? = this.error, + ): JSONRPCResponse = JSONRPCResponse(id, jsonrpc, result, error) +} /** * An incomplete set of error codes that may appear in JSON-RPC responses. From 5e4851d9e62af46636c1d310a36d37401030b6bb Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 14 Jul 2025 02:03:46 +0200 Subject: [PATCH 3/6] improve StreamableHttpClientTransport handling of headers, event parsing, and session termination --- .../kotlin/sdk/client/StreamableHttpClientTransport.kt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 8baf863..cc07c84 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -207,6 +207,7 @@ public class StreamableHttpClientTransport( } sessionId = null + lastEventId = null logger.debug { "Session terminated successfully" } } @@ -247,6 +248,7 @@ public class StreamableHttpClientTransport( private fun applyCommonHeaders(builder: HttpRequestBuilder) { builder.headers { + append(HttpHeaders.Accept, ContentType.Application.Json.toString()) sessionId?.let { append(MCP_SESSION_ID_HEADER, it) } protocolVersion?.let { append(MCP_PROTOCOL_VERSION_HEADER, it) } } @@ -330,9 +332,9 @@ public class StreamableHttpClientTransport( continue } when { - line.startsWith("id:") -> line.substring(3).trim() - line.startsWith("event:") -> eventName = line.substring(6).trim() - line.startsWith("data:") -> sb.append(line.substring(5).trim()) + line.startsWith("id:") -> id = line.substringAfter("id:").trim() + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + line.startsWith("data:") -> sb.appendLine(line.substringAfter("data:").trim()) } } } From cd26c917635183b17a76d270cdb909d8e2d83351 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 14 Jul 2025 03:38:17 +0200 Subject: [PATCH 4/6] add simple tests for StreamableClient and fix send --- build.gradle.kts | 1 + .../client/StreamableHttpClientTransport.kt | 65 +++--- .../StreamableHttpClientTransportTest.kt | 209 ++++++++++++++++++ 3 files changed, 240 insertions(+), 35 deletions(-) create mode 100644 src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index 1f826bd..c1366a3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -268,6 +268,7 @@ kotlin { jvmTest { dependencies { + implementation(libs.ktor.client.mock) implementation(libs.mockk) implementation(libs.slf4j.simple) } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index cc07c84..c6c8a23 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -125,43 +125,39 @@ public class StreamableHttpClientTransport( response.headers[MCP_SESSION_ID_HEADER]?.let { sessionId = it } - if (message is JSONRPCNotification || message is JSONRPCResponse) { - if (response.status != HttpStatusCode.Accepted) { - val text = response.bodyAsText() - val err = StreamableHttpError(response.status.value, text) - logger.error(err) { "Client POST request failed." } - _onError(err) - throw err + if (response.status == HttpStatusCode.Accepted) { + if (message is JSONRPCNotification && message.method == "notifications/initialized") { + startSseSession(onResumptionToken = onResumptionToken) } return } - when { - !response.status.isSuccess() -> { - val text = response.bodyAsText() - val err = StreamableHttpError(response.status.value, text) - logger.error(err) { "Client POST request failed." } - _onError(err) - throw err - } - - response.contentType()?.match(ContentType.Application.Json) ?: false -> - response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> - runCatching { McpJson.decodeFromString(json) } - .onSuccess { _onMessage(it) } - .onFailure(_onError) - } - - response.contentType()?.match(ContentType.Text.EventStream) ?: false -> - handleInlineSse( - response, onResumptionToken = onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null - ) + if (!response.status.isSuccess()) { + val error = StreamableHttpError(response.status.value, response.bodyAsText()) + _onError(error) + throw error } - // If client just sent InitializedNotification, open SSE stream - if (message is JSONRPCNotification && message.method == "notifications/initialized" && sseSession == null) { - startSseSession() + when (response.contentType()?.withoutParameters()) { + ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> + runCatching { McpJson.decodeFromString(json) } + .onSuccess { _onMessage(it) } + .onFailure(_onError) + } + + ContentType.Text.EventStream -> handleInlineSse( + response, onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null + ) + else -> { + val body = response.bodyAsText() + if (response.contentType() == null && body.isBlank()) return + + val ct = response.contentType()?.toString() ?: "" + val error = StreamableHttpError(-1, "Unexpected content type: $$ct") + _onError(error) + throw error + } } } @@ -297,7 +293,6 @@ public class StreamableHttpClientTransport( ) { logger.trace { "Handling inline SSE from POST response" } val channel = response.bodyAsChannel() - val reader = channel val sb = StringBuilder() var id: String? = null @@ -325,8 +320,8 @@ public class StreamableHttpClientTransport( sb.clear() } - while (!reader.isClosedForRead) { - val line = reader.readUTF8Line() ?: break + while (!channel.isClosedForRead) { + val line = channel.readUTF8Line() ?: break if (line.isEmpty()) { dispatch(sb.toString()) continue @@ -334,7 +329,7 @@ public class StreamableHttpClientTransport( when { line.startsWith("id:") -> id = line.substringAfter("id:").trim() line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() - line.startsWith("data:") -> sb.appendLine(line.substringAfter("data:").trim()) + line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim()) } } } diff --git a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt new file mode 100644 index 0000000..b5a4db0 --- /dev/null +++ b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt @@ -0,0 +1,209 @@ +package client + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.client.plugins.sse.SSE +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.content.TextContent +import io.ktor.utils.io.ByteReadChannel +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.buildJsonObject +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.time.Duration.Companion.seconds + +class StreamableHttpClientTransportTest { + private lateinit var mockEngine: MockEngine + private lateinit var httpClient: HttpClient + private lateinit var transport: StreamableHttpClientTransport + + @BeforeTest + fun setup() { + mockEngine = MockEngine { + respond( + ByteReadChannel(""), + status = HttpStatusCode.OK, + ) + } + + httpClient = HttpClient(mockEngine) { + install(SSE) { + reconnectionTime = 1.seconds + } + } + + transport = StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp") + } + + @AfterTest + fun teardown() { + httpClient.close() + } + + @Test + fun testSendJsonRpcMessage() = runTest { + val message = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "test", + params = buildJsonObject { } + ) + + mockEngine.config.addHandler { request -> + assertEquals(HttpMethod.Post, request.method) + assertEquals("http://localhost:8080/mcp", request.url.toString()) + assertEquals(ContentType.Application.Json, request.body.contentType) + + val body = (request.body as TextContent).text + val decodedMessage = McpJson.decodeFromString(body) + assertEquals(message, decodedMessage) + + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + transport.start() + transport.send(message) + } + +// @Test +// fun testStoreSessionId() = runTest { +// val initMessage = JSONRPCRequest( +// id = RequestId.StringId("test-id"), +// method = "initialize", +// params = buildJsonObject { +// put("clientInfo", buildJsonObject { +// put("name", JsonPrimitive("test-client")) +// put("version", JsonPrimitive("1.0")) +// }) +// put("protocolVersion", JsonPrimitive("2025-06-18")) +// } +// ) +// +// mockEngine.config.addHandler { request -> +// respond( +// content = "", status = HttpStatusCode.OK, +// headers = headersOf("mcp-session-id", "test-session-id") +// ) +// } +// +// transport.start() +// transport.send(initMessage) +// +// assertEquals("test-session-id", transport.sessionId) +// +// // Send another message and verify session ID is included +// mockEngine.config.addHandler { request -> +// assertEquals("test-session-id", request.headers["mcp-session-id"]) +// respond( +// content = "", +// status = HttpStatusCode.Accepted +// ) +// } +// +// transport.send(JSONRPCNotification(method = "test")) +// } + + @Test + fun testTerminateSession() = runTest { +// transport.sessionId = "test-session-id" + + mockEngine.config.addHandler { request -> + assertEquals(HttpMethod.Delete, request.method) + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.OK + ) + } + + transport.start() + transport.terminateSession() + + assertNull(transport.sessionId) + } + + @Test + fun testTerminateSessionHandle405() = runTest { +// transport.sessionId = "test-session-id" + + mockEngine.config.addHandler { request -> + assertEquals(HttpMethod.Delete, request.method) + respond( + content = "", + status = HttpStatusCode.MethodNotAllowed + ) + } + + transport.start() + // Should not throw for 405 + assertDoesNotThrow { + transport.terminateSession() + } + + // Session ID should still be cleared + assertNull(transport.sessionId) + } + + @Test + fun testProtocolVersionHeader() = runTest { + transport.protocolVersion = "2025-06-18" + + mockEngine.config.addHandler { request -> + assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + transport.start() + transport.send(JSONRPCNotification(method = "test")) + } + + @Test + fun testHandle405ForSSE() = runTest { + mockEngine.config.addHandler { request -> + if (request.method == HttpMethod.Get) { + respond( + content = "", + status = HttpStatusCode.MethodNotAllowed + ) + } else { + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + } + + transport.start() + + // Start SSE session - should handle 405 gracefully + val initNotification = JSONRPCNotification( + method = "notifications/initialized", + ) + + // Should not throw + assertDoesNotThrow { + transport.send(initNotification) + } + + // Transport should still work after 405 + transport.send(JSONRPCNotification(method = "test")) + } +} From b4da239fa73e5894a1ecb17f09d9f5e5ffa28a6c Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 14 Jul 2025 03:38:29 +0200 Subject: [PATCH 5/6] add client mocking --- gradle/libs.versions.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 40dc628..1227aab 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -32,6 +32,7 @@ ktor-server-cio = { group = "io.ktor", name = "ktor-server-cio", version.ref = " kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "coroutines" } ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } +ktor-client-mock = { group = "io.ktor", name = "ktor-client-mock", version.ref = "ktor" } mockk = { group = "io.mockk", name = "mockk", version.ref = "mockk" } slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", version.ref = "kotest" } From 568da8552473c898fe35993c645b1271fec89313 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 14 Jul 2025 14:23:20 +0200 Subject: [PATCH 6/6] add notification e2e tests --- .../client/StreamableHttpClientTransport.kt | 4 +- .../StreamableHttpClientTransportTest.kt | 326 +++++++++++++----- 2 files changed, 245 insertions(+), 85 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index c6c8a23..6584bc1 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -6,6 +6,7 @@ import io.ktor.client.plugins.ClientRequestException import io.ktor.client.plugins.sse.ClientSSESession import io.ktor.client.plugins.sse.sseSession import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.accept import io.ktor.client.request.delete import io.ktor.client.request.headers import io.ktor.client.request.post @@ -223,7 +224,7 @@ public class StreamableHttpClientTransport( ) { method = HttpMethod.Get applyCommonHeaders(this) - headers.append(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + accept(ContentType.Text.EventStream) (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } requestBuilder() } @@ -244,7 +245,6 @@ public class StreamableHttpClientTransport( private fun applyCommonHeaders(builder: HttpRequestBuilder) { builder.headers { - append(HttpHeaders.Accept, ContentType.Application.Json.toString()) sessionId?.let { append(MCP_SESSION_ID_HEADER, it) } protocolVersion?.let { append(MCP_PROTOCOL_VERSION_HEADER, it) } } diff --git a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt index b5a4db0..ca255f4 100644 --- a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt +++ b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt @@ -2,12 +2,15 @@ package client import io.ktor.client.HttpClient import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.MockRequestHandler import io.ktor.client.engine.mock.respond import io.ktor.client.plugins.sse.SSE import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.http.content.TextContent +import io.ktor.http.headersOf import io.ktor.utils.io.ByteReadChannel import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification @@ -15,42 +18,30 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.RequestId import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.delay import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject import org.junit.jupiter.api.assertDoesNotThrow -import kotlin.test.AfterTest -import kotlin.test.BeforeTest +import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNull +import kotlin.test.assertTrue import kotlin.time.Duration.Companion.seconds class StreamableHttpClientTransportTest { - private lateinit var mockEngine: MockEngine - private lateinit var httpClient: HttpClient - private lateinit var transport: StreamableHttpClientTransport - @BeforeTest - fun setup() { - mockEngine = MockEngine { - respond( - ByteReadChannel(""), - status = HttpStatusCode.OK, - ) - } - - httpClient = HttpClient(mockEngine) { + private fun createTransport(handler: MockRequestHandler): StreamableHttpClientTransport { + val mockEngine = MockEngine(handler) + val httpClient = HttpClient(mockEngine) { install(SSE) { reconnectionTime = 1.seconds } } - transport = StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp") - } - - @AfterTest - fun teardown() { - httpClient.close() + return StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp") } @Test @@ -61,7 +52,7 @@ class StreamableHttpClientTransportTest { params = buildJsonObject { } ) - mockEngine.config.addHandler { request -> + val transport = createTransport { request -> assertEquals(HttpMethod.Post, request.method) assertEquals("http://localhost:8080/mcp", request.url.toString()) assertEquals(ContentType.Application.Json, request.body.contentType) @@ -78,51 +69,57 @@ class StreamableHttpClientTransportTest { transport.start() transport.send(message) + transport.close() } -// @Test -// fun testStoreSessionId() = runTest { -// val initMessage = JSONRPCRequest( -// id = RequestId.StringId("test-id"), -// method = "initialize", -// params = buildJsonObject { -// put("clientInfo", buildJsonObject { -// put("name", JsonPrimitive("test-client")) -// put("version", JsonPrimitive("1.0")) -// }) -// put("protocolVersion", JsonPrimitive("2025-06-18")) -// } -// ) -// -// mockEngine.config.addHandler { request -> -// respond( -// content = "", status = HttpStatusCode.OK, -// headers = headersOf("mcp-session-id", "test-session-id") -// ) -// } -// -// transport.start() -// transport.send(initMessage) -// -// assertEquals("test-session-id", transport.sessionId) -// -// // Send another message and verify session ID is included -// mockEngine.config.addHandler { request -> -// assertEquals("test-session-id", request.headers["mcp-session-id"]) -// respond( -// content = "", -// status = HttpStatusCode.Accepted -// ) -// } -// -// transport.send(JSONRPCNotification(method = "test")) -// } + @Test + fun testStoreSessionId() = runTest { + val initMessage = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "initialize", + params = buildJsonObject { + put("clientInfo", buildJsonObject { + put("name", JsonPrimitive("test-client")) + put("version", JsonPrimitive("1.0")) + }) + put("protocolVersion", JsonPrimitive("2025-06-18")) + } + ) + + val transport = createTransport { request -> + when (val msg = McpJson.decodeFromString((request.body as TextContent).text)) { + is JSONRPCRequest if msg.method == "initialize" -> respond( + content = "", status = HttpStatusCode.OK, + headers = headersOf("mcp-session-id", "test-session-id") + ) + + is JSONRPCNotification if msg.method == "test" -> { + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + else -> error("Unexpected message: $msg") + } + } + + transport.start() + transport.send(initMessage) + + assertEquals("test-session-id", transport.sessionId) + + transport.send(JSONRPCNotification(method = "test")) + + transport.close() + } @Test fun testTerminateSession() = runTest { // transport.sessionId = "test-session-id" - mockEngine.config.addHandler { request -> + val transport = createTransport { request -> assertEquals(HttpMethod.Delete, request.method) assertEquals("test-session-id", request.headers["mcp-session-id"]) respond( @@ -135,13 +132,14 @@ class StreamableHttpClientTransportTest { transport.terminateSession() assertNull(transport.sessionId) + transport.close() } @Test fun testTerminateSessionHandle405() = runTest { // transport.sessionId = "test-session-id" - mockEngine.config.addHandler { request -> + val transport = createTransport { request -> assertEquals(HttpMethod.Delete, request.method) respond( content = "", @@ -157,53 +155,215 @@ class StreamableHttpClientTransportTest { // Session ID should still be cleared assertNull(transport.sessionId) + transport.close() } @Test fun testProtocolVersionHeader() = runTest { - transport.protocolVersion = "2025-06-18" - - mockEngine.config.addHandler { request -> + val transport = createTransport { request -> assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) respond( content = "", status = HttpStatusCode.Accepted ) } + transport.protocolVersion = "2025-06-18" transport.start() transport.send(JSONRPCNotification(method = "test")) + transport.close() } + @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") @Test - fun testHandle405ForSSE() = runTest { - mockEngine.config.addHandler { request -> - if (request.method == HttpMethod.Get) { - respond( - content = "", - status = HttpStatusCode.MethodNotAllowed - ) - } else { - respond( - content = "", - status = HttpStatusCode.Accepted - ) + fun testNotificationSchemaE2E() = runTest { + val receivedMessages = mutableListOf() + var sseStarted = false + + val transport = createTransport { request -> + when (request.method) { + HttpMethod.Post if request.body.toString().contains("notifications/initialized") -> { + respond( + content = "", + status = HttpStatusCode.Accepted, + headers = headersOf("mcp-session-id", "notification-test-session") + ) + } + + // Handle SSE connection + HttpMethod.Get -> { + sseStarted = true + val sseContent = buildString { + // Server sends various notifications + appendLine("event: message") + appendLine("id: 1") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""") + appendLine() + + appendLine("event: message") + appendLine("id: 2") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resources/list_changed"}""") + appendLine() + + appendLine("event: message") + appendLine("id: 3") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/tools/list_changed"}""") + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, ContentType.Text.EventStream.toString() + ) + ) + } + + // Handle regular notifications + HttpMethod.Post -> { + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + else -> respond("", HttpStatusCode.OK) } } + transport.onMessage { message -> + receivedMessages.add(message) + } + transport.start() - // Start SSE session - should handle 405 gracefully - val initNotification = JSONRPCNotification( + // Test 1: Send initialized notification to trigger SSE + val initializedNotification = JSONRPCNotification( method = "notifications/initialized", + params = buildJsonObject { + put("protocolVersion", JsonPrimitive("1.0")) + put("capabilities", buildJsonObject { + put("tools", JsonPrimitive(true)) + put("resources", JsonPrimitive(true)) + }) + } ) - // Should not throw - assertDoesNotThrow { - transport.send(initNotification) + transport.send(initializedNotification) + + // Verify SSE was triggered + assertTrue(sseStarted, "SSE should start after initialized notification") + + // Test 2: Verify received notifications + assertEquals(3, receivedMessages.size) + assertTrue(receivedMessages.all { it is JSONRPCNotification }) + + val notifications = receivedMessages.filterIsInstance() + + // Verify progress notification + val progressNotif = notifications[0] + assertEquals("notifications/progress", progressNotif.method) + val progressParams = progressNotif.params as JsonObject + assertEquals("upload-123", (progressParams["progressToken"] as JsonPrimitive).content) + assertEquals(50, (progressParams["progress"] as JsonPrimitive).content.toInt()) + + // Verify list changed notifications + assertEquals("notifications/resources/list_changed", notifications[1].method) + assertEquals("notifications/tools/list_changed", notifications[2].method) + + // Test 3: Send various client notifications + val clientNotifications = listOf( + JSONRPCNotification( + method = "notifications/progress", + params = buildJsonObject { + put("progressToken", JsonPrimitive("download-456")) + put("progress", JsonPrimitive(75)) + } + ), + JSONRPCNotification( + method = "notifications/cancelled", + params = buildJsonObject { + put("requestId", JsonPrimitive("req-789")) + put("reason", JsonPrimitive("user_cancelled")) + } + ), + JSONRPCNotification( + method = "notifications/message", + params = buildJsonObject { + put("level", JsonPrimitive("info")) + put("message", JsonPrimitive("Operation completed")) + put("data", buildJsonObject { + put("duration", JsonPrimitive(1234)) + }) + } + ) + ) + + // Send all client notifications + clientNotifications.forEach { notification -> + transport.send(notification) } - // Transport should still work after 405 - transport.send(JSONRPCNotification(method = "test")) + // Verify session ID is maintained + assertEquals("notification-test-session", transport.sessionId) + transport.close() + } + + @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") + @Test + fun testNotificationWithResumptionToken() = runTest { + var resumptionTokenReceived: String? = null + var lastEventIdSent: String? = null + + val transport = createTransport { request -> + // Capture Last-Event-ID header + lastEventIdSent = request.headers["Last-Event-ID"] + + when (request.method) { + HttpMethod.Get -> { + val sseContent = buildString { + appendLine("event: message") + appendLine("id: resume-100") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"${lastEventIdSent}"}}""") + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, ContentType.Text.EventStream.toString() + ) + ) + } + + else -> respond("", HttpStatusCode.Accepted) + } + } + + transport.start() + + // Send notification with resumption token + transport.send( + message = JSONRPCNotification( + method = "notifications/test", + params = buildJsonObject { + put("data", JsonPrimitive("test-data")) + } + ), + resumptionToken = "previous-token-99", + onResumptionToken = { token -> + resumptionTokenReceived = token + } + ) + + // Wait for response + delay(1.seconds) + + // Verify resumption token was sent in header + assertEquals("previous-token-99", lastEventIdSent) + + // Verify new resumption token was received + assertEquals("resume-100", resumptionTokenReceived) + transport.close() } }