diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index e04795b..8d356d5 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -1070,6 +1070,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/JSONRPCResponse : io/model public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse$Companion; public fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;)V public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse;Lio/modelcontextprotocol/kotlin/sdk/RequestId;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/RequestResult;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCResponse; public final fun getError ()Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; public final fun getId ()Lio/modelcontextprotocol/kotlin/sdk/RequestId; public final fun getJsonrpc ()Ljava/lang/String; @@ -2947,6 +2949,34 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTranspor public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getProtocolVersion ()Ljava/lang/String; + public final fun getSessionId ()Ljava/lang/String; + public final fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun send$default (Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun setProtocolVersion (Ljava/lang/String;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun terminateSession (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError : java/lang/Exception { + public fun ()V + public fun (Ljava/lang/Integer;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/Integer;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCode ()Ljava/lang/Integer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; +} + public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/build.gradle.kts b/build.gradle.kts index 1f826bd..c1366a3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -268,6 +268,7 @@ kotlin { jvmTest { dependencies { + implementation(libs.ktor.client.mock) implementation(libs.mockk) implementation(libs.slf4j.simple) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 40dc628..1227aab 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -32,6 +32,7 @@ ktor-server-cio = { group = "io.ktor", name = "ktor-server-cio", version.ref = " kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "coroutines" } ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } +ktor-client-mock = { group = "io.ktor", name = "ktor-client-mock", version.ref = "ktor" } mockk = { group = "io.mockk", name = "mockk", version.ref = "mockk" } slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", version.ref = "kotest" } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt new file mode 100644 index 0000000..6584bc1 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -0,0 +1,336 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.plugins.ClientRequestException +import io.ktor.client.plugins.sse.ClientSSESession +import io.ktor.client.plugins.sse.sseSession +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.accept +import io.ktor.client.request.delete +import io.ktor.client.request.headers +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.http.isSuccess +import io.ktor.utils.io.readUTF8Line +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.time.Duration + +private val logger = KotlinLogging.logger {} + +private const val MCP_SESSION_ID_HEADER = "mcp-session-id" +private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" +private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" + +/** + * Error class for Streamable HTTP transport errors. + */ +public class StreamableHttpError( + public val code: Int? = null, + message: String? = null +) : Exception("Streamable HTTP error: $message") + +/** + * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events + * for receiving messages. + */ +@OptIn(ExperimentalAtomicApi::class) +public class StreamableHttpClientTransport( + private val client: HttpClient, + private val url: String, + private val reconnectionTime: Duration? = null, + private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, +) : AbstractTransport() { + + public var sessionId: String? = null + private set + public var protocolVersion: String? = null + + private val initialized: AtomicBoolean = AtomicBoolean(false) + + private var sseSession: ClientSSESession? = null + private var sseJob: Job? = null + + private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) } + + private var lastEventId: String? = null + + override suspend fun start() { + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { + error("StreamableHttpClientTransport already started!") + } + logger.debug { "Client transport starting..." } + } + + /** + * Sends a single message with optional resumption support + */ + override suspend fun send(message: JSONRPCMessage) { + send(message, null) + } + + /** + * Sends one or more messages with optional resumption support. + * This is the main send method that matches the TypeScript implementation. + */ + public suspend fun send( + message: JSONRPCMessage, + resumptionToken: String?, + onResumptionToken: ((String) -> Unit)? = null + ) { + logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } + + // If we have a resumption token, reconnect the SSE stream with it + resumptionToken?.let { token -> + startSseSession( + resumptionToken = token, onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null + ) + return + } + + val jsonBody = McpJson.encodeToString(message) + val response = client.post(url) { + applyCommonHeaders(this) + headers.append(HttpHeaders.Accept, "${ContentType.Application.Json}, ${ContentType.Text.EventStream}") + contentType(ContentType.Application.Json) + setBody(jsonBody) + requestBuilder() + } + + response.headers[MCP_SESSION_ID_HEADER]?.let { sessionId = it } + + if (response.status == HttpStatusCode.Accepted) { + if (message is JSONRPCNotification && message.method == "notifications/initialized") { + startSseSession(onResumptionToken = onResumptionToken) + } + return + } + + if (!response.status.isSuccess()) { + val error = StreamableHttpError(response.status.value, response.bodyAsText()) + _onError(error) + throw error + } + + when (response.contentType()?.withoutParameters()) { + ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> + runCatching { McpJson.decodeFromString(json) } + .onSuccess { _onMessage(it) } + .onFailure(_onError) + } + + ContentType.Text.EventStream -> handleInlineSse( + response, onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null + ) + else -> { + val body = response.bodyAsText() + if (response.contentType() == null && body.isBlank()) return + + val ct = response.contentType()?.toString() ?: "" + val error = StreamableHttpError(-1, "Unexpected content type: $$ct") + _onError(error) + throw error + } + } + } + + override suspend fun close() { + if (!initialized.load()) return // Already closed or never started + logger.debug { "Client transport closing." } + + try { + // Try to terminate session if we have one + terminateSession() + + sseSession?.cancel() + sseJob?.cancelAndJoin() + scope.cancel() + } catch (_: Exception) { + // Ignore errors during cleanup + } finally { + initialized.store(false) + _onClose() + } + } + + /** + * Terminates the current session by sending a DELETE request to the server. + */ + public suspend fun terminateSession() { + if (sessionId == null) return + logger.debug { "Terminating session: $sessionId" } + val response = client.delete(url) { + applyCommonHeaders(this) + requestBuilder() + } + + // 405 means server doesn't support explicit session termination + if (!response.status.isSuccess() && response.status != HttpStatusCode.MethodNotAllowed) { + val error = StreamableHttpError( + response.status.value, + "Failed to terminate session: ${response.status.description}" + ) + logger.error(error) { "Failed to terminate session" } + _onError(error) + throw error + } + + sessionId = null + lastEventId = null + logger.debug { "Session terminated successfully" } + } + + private suspend fun startSseSession( + resumptionToken: String? = null, + replayMessageId: RequestId? = null, + onResumptionToken: ((String) -> Unit)? = null + ) { + sseSession?.cancel() + sseJob?.cancelAndJoin() + + logger.debug { "Client attempting to start SSE session at url: $url" } + try { + sseSession = client.sseSession( + urlString = url, + reconnectionTime = reconnectionTime, + ) { + method = HttpMethod.Get + applyCommonHeaders(this) + accept(ContentType.Text.EventStream) + (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } + requestBuilder() + } + logger.debug { "Client SSE session started successfully." } + } catch (e: ClientRequestException) { + if (e.response.status == HttpStatusCode.MethodNotAllowed) { + logger.info { "Server returned 405 for GET/SSE, stream disabled." } + return + } + _onError(e) + throw e + } + + sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { + sseSession?.let { collectSse(it, replayMessageId, onResumptionToken) } + } + } + + private fun applyCommonHeaders(builder: HttpRequestBuilder) { + builder.headers { + sessionId?.let { append(MCP_SESSION_ID_HEADER, it) } + protocolVersion?.let { append(MCP_PROTOCOL_VERSION_HEADER, it) } + } + } + + private suspend fun collectSse( + session: ClientSSESession, + replayMessageId: RequestId?, + onResumptionToken: ((String) -> Unit)? + ) { + try { + session.incoming.collect { event -> + event.id?.let { + lastEventId = it + onResumptionToken?.invoke(it) + } + logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}, id=${event.id}" } + when (event.event) { + null, "message" -> + event.data?.takeIf { it.isNotEmpty() }?.let { json -> + runCatching { McpJson.decodeFromString(json) } + .onSuccess { msg -> + if (replayMessageId != null && msg is JSONRPCResponse) { + _onMessage(msg.copy(id = replayMessageId)) + } else { + _onMessage(msg) + } + } + .onFailure(_onError) + } + + "error" -> _onError(StreamableHttpError(null, event.data)) + } + } + } catch (_: CancellationException) { + // ignore + } catch (t: Throwable) { + _onError(t) + } + } + + private suspend fun handleInlineSse( + response: HttpResponse, + replayMessageId: RequestId?, + onResumptionToken: ((String) -> Unit)? + ) { + logger.trace { "Handling inline SSE from POST response" } + val channel = response.bodyAsChannel() + + val sb = StringBuilder() + var id: String? = null + var eventName: String? = null + + suspend fun dispatch(data: String) { + id?.let { + lastEventId = it + onResumptionToken?.invoke(it) + } + if (eventName == null || eventName == "message") { + runCatching { McpJson.decodeFromString(data) } + .onSuccess { msg -> + if (replayMessageId != null && msg is JSONRPCResponse) { + _onMessage(msg.copy(id = replayMessageId)) + } else { + _onMessage(msg) + } + } + .onFailure(_onError) + } + // reset + id = null + eventName = null + sb.clear() + } + + while (!channel.isClosedForRead) { + val line = channel.readUTF8Line() ?: break + if (line.isEmpty()) { + dispatch(sb.toString()) + continue + } + when { + line.startsWith("id:") -> id = line.substringAfter("id:").trim() + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim()) + } + } + } +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt new file mode 100644 index 0000000..c2454e1 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.ktor.client.HttpClient +import io.ktor.client.request.HttpRequestBuilder +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION +import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +import kotlin.time.Duration + +/** + * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A [StreamableHttpClientTransport] configured for MCP communication. + */ +public fun HttpClient.mcpStreamableHttpTransport( + url: String, + reconnectionTime: Duration? = null, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): StreamableHttpClientTransport = + StreamableHttpClientTransport(this, url, reconnectionTime, requestBuilder = requestBuilder) + +/** + * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A connected [Client] ready for MCP communication. + */ +public suspend fun HttpClient.mcpStreamableHttp( + url: String, + reconnectionTime: Duration? = null, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): Client { + val transport = mcpStreamableHttpTransport(url, reconnectionTime, requestBuilder) + val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) + client.connect(transport) + return client +} \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 893a8d6..c93519b 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -244,7 +244,15 @@ public class JSONRPCResponse( public val jsonrpc: String = JSONRPC_VERSION, public val result: RequestResult? = null, public val error: JSONRPCError? = null, -) : JSONRPCMessage +) : JSONRPCMessage { + + public fun copy( + id: RequestId = this.id, + jsonrpc: String = this.jsonrpc, + result: RequestResult? = this.result, + error: JSONRPCError? = this.error, + ): JSONRPCResponse = JSONRPCResponse(id, jsonrpc, result, error) +} /** * An incomplete set of error codes that may appear in JSON-RPC responses. diff --git a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt new file mode 100644 index 0000000..ca255f4 --- /dev/null +++ b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt @@ -0,0 +1,369 @@ +package client + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.MockRequestHandler +import io.ktor.client.engine.mock.respond +import io.ktor.client.plugins.sse.SSE +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteReadChannel +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.delay +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.Ignore +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class StreamableHttpClientTransportTest { + + private fun createTransport(handler: MockRequestHandler): StreamableHttpClientTransport { + val mockEngine = MockEngine(handler) + val httpClient = HttpClient(mockEngine) { + install(SSE) { + reconnectionTime = 1.seconds + } + } + + return StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp") + } + + @Test + fun testSendJsonRpcMessage() = runTest { + val message = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "test", + params = buildJsonObject { } + ) + + val transport = createTransport { request -> + assertEquals(HttpMethod.Post, request.method) + assertEquals("http://localhost:8080/mcp", request.url.toString()) + assertEquals(ContentType.Application.Json, request.body.contentType) + + val body = (request.body as TextContent).text + val decodedMessage = McpJson.decodeFromString(body) + assertEquals(message, decodedMessage) + + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + transport.start() + transport.send(message) + transport.close() + } + + @Test + fun testStoreSessionId() = runTest { + val initMessage = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "initialize", + params = buildJsonObject { + put("clientInfo", buildJsonObject { + put("name", JsonPrimitive("test-client")) + put("version", JsonPrimitive("1.0")) + }) + put("protocolVersion", JsonPrimitive("2025-06-18")) + } + ) + + val transport = createTransport { request -> + when (val msg = McpJson.decodeFromString((request.body as TextContent).text)) { + is JSONRPCRequest if msg.method == "initialize" -> respond( + content = "", status = HttpStatusCode.OK, + headers = headersOf("mcp-session-id", "test-session-id") + ) + + is JSONRPCNotification if msg.method == "test" -> { + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + else -> error("Unexpected message: $msg") + } + } + + transport.start() + transport.send(initMessage) + + assertEquals("test-session-id", transport.sessionId) + + transport.send(JSONRPCNotification(method = "test")) + + transport.close() + } + + @Test + fun testTerminateSession() = runTest { +// transport.sessionId = "test-session-id" + + val transport = createTransport { request -> + assertEquals(HttpMethod.Delete, request.method) + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.OK + ) + } + + transport.start() + transport.terminateSession() + + assertNull(transport.sessionId) + transport.close() + } + + @Test + fun testTerminateSessionHandle405() = runTest { +// transport.sessionId = "test-session-id" + + val transport = createTransport { request -> + assertEquals(HttpMethod.Delete, request.method) + respond( + content = "", + status = HttpStatusCode.MethodNotAllowed + ) + } + + transport.start() + // Should not throw for 405 + assertDoesNotThrow { + transport.terminateSession() + } + + // Session ID should still be cleared + assertNull(transport.sessionId) + transport.close() + } + + @Test + fun testProtocolVersionHeader() = runTest { + val transport = createTransport { request -> + assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + transport.protocolVersion = "2025-06-18" + + transport.start() + transport.send(JSONRPCNotification(method = "test")) + transport.close() + } + + @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") + @Test + fun testNotificationSchemaE2E() = runTest { + val receivedMessages = mutableListOf() + var sseStarted = false + + val transport = createTransport { request -> + when (request.method) { + HttpMethod.Post if request.body.toString().contains("notifications/initialized") -> { + respond( + content = "", + status = HttpStatusCode.Accepted, + headers = headersOf("mcp-session-id", "notification-test-session") + ) + } + + // Handle SSE connection + HttpMethod.Get -> { + sseStarted = true + val sseContent = buildString { + // Server sends various notifications + appendLine("event: message") + appendLine("id: 1") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""") + appendLine() + + appendLine("event: message") + appendLine("id: 2") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resources/list_changed"}""") + appendLine() + + appendLine("event: message") + appendLine("id: 3") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/tools/list_changed"}""") + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, ContentType.Text.EventStream.toString() + ) + ) + } + + // Handle regular notifications + HttpMethod.Post -> { + respond( + content = "", + status = HttpStatusCode.Accepted + ) + } + + else -> respond("", HttpStatusCode.OK) + } + } + + transport.onMessage { message -> + receivedMessages.add(message) + } + + transport.start() + + // Test 1: Send initialized notification to trigger SSE + val initializedNotification = JSONRPCNotification( + method = "notifications/initialized", + params = buildJsonObject { + put("protocolVersion", JsonPrimitive("1.0")) + put("capabilities", buildJsonObject { + put("tools", JsonPrimitive(true)) + put("resources", JsonPrimitive(true)) + }) + } + ) + + transport.send(initializedNotification) + + // Verify SSE was triggered + assertTrue(sseStarted, "SSE should start after initialized notification") + + // Test 2: Verify received notifications + assertEquals(3, receivedMessages.size) + assertTrue(receivedMessages.all { it is JSONRPCNotification }) + + val notifications = receivedMessages.filterIsInstance() + + // Verify progress notification + val progressNotif = notifications[0] + assertEquals("notifications/progress", progressNotif.method) + val progressParams = progressNotif.params as JsonObject + assertEquals("upload-123", (progressParams["progressToken"] as JsonPrimitive).content) + assertEquals(50, (progressParams["progress"] as JsonPrimitive).content.toInt()) + + // Verify list changed notifications + assertEquals("notifications/resources/list_changed", notifications[1].method) + assertEquals("notifications/tools/list_changed", notifications[2].method) + + // Test 3: Send various client notifications + val clientNotifications = listOf( + JSONRPCNotification( + method = "notifications/progress", + params = buildJsonObject { + put("progressToken", JsonPrimitive("download-456")) + put("progress", JsonPrimitive(75)) + } + ), + JSONRPCNotification( + method = "notifications/cancelled", + params = buildJsonObject { + put("requestId", JsonPrimitive("req-789")) + put("reason", JsonPrimitive("user_cancelled")) + } + ), + JSONRPCNotification( + method = "notifications/message", + params = buildJsonObject { + put("level", JsonPrimitive("info")) + put("message", JsonPrimitive("Operation completed")) + put("data", buildJsonObject { + put("duration", JsonPrimitive(1234)) + }) + } + ) + ) + + // Send all client notifications + clientNotifications.forEach { notification -> + transport.send(notification) + } + + // Verify session ID is maintained + assertEquals("notification-test-session", transport.sessionId) + transport.close() + } + + @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") + @Test + fun testNotificationWithResumptionToken() = runTest { + var resumptionTokenReceived: String? = null + var lastEventIdSent: String? = null + + val transport = createTransport { request -> + // Capture Last-Event-ID header + lastEventIdSent = request.headers["Last-Event-ID"] + + when (request.method) { + HttpMethod.Get -> { + val sseContent = buildString { + appendLine("event: message") + appendLine("id: resume-100") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"${lastEventIdSent}"}}""") + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, ContentType.Text.EventStream.toString() + ) + ) + } + + else -> respond("", HttpStatusCode.Accepted) + } + } + + transport.start() + + // Send notification with resumption token + transport.send( + message = JSONRPCNotification( + method = "notifications/test", + params = buildJsonObject { + put("data", JsonPrimitive("test-data")) + } + ), + resumptionToken = "previous-token-99", + onResumptionToken = { token -> + resumptionTokenReceived = token + } + ) + + // Wait for response + delay(1.seconds) + + // Verify resumption token was sent in header + assertEquals("previous-token-99", lastEventIdSent) + + // Verify new resumption token was received + assertEquals("resume-100", resumptionTokenReceived) + transport.close() + } +}