From 6228efe1d34bea1df7fe87c20306a1c8e316ec3f Mon Sep 17 00:00:00 2001 From: Melvin Biamont Date: Wed, 9 Jul 2025 18:14:18 -0300 Subject: [PATCH] Add RegisteredTool context --- .../kotlin/sdk/server/Server.kt | 20 ++++---- src/jvmTest/kotlin/server/ServerTest.kt | 50 +++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index e98c4ced..ce0f980e 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -197,7 +197,7 @@ public open class Server( description: String, inputSchema: Tool.Input = Tool.Input(), toolAnnotations: ToolAnnotations? = null, - handler: suspend (CallToolRequest) -> CallToolResult + handler: suspend Server.(CallToolRequest) -> CallToolResult ) { if (capabilities.tools == null) { logger.error { "Failed to add tool '$name': Server does not support tools capability" } @@ -285,7 +285,7 @@ public open class Server( * @param promptProvider A suspend function that returns the prompt content when requested by the client. * @throws IllegalStateException If the server does not support prompts. */ - public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) { + public fun addPrompt(prompt: Prompt, promptProvider: suspend Server.(GetPromptRequest) -> GetPromptResult) { if (capabilities.prompts == null) { logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") @@ -307,7 +307,7 @@ public open class Server( name: String, description: String? = null, arguments: List? = null, - promptProvider: suspend (GetPromptRequest) -> GetPromptResult + promptProvider: suspend Server.(GetPromptRequest) -> GetPromptResult ) { val prompt = Prompt(name = name, description = description, arguments = arguments) addPrompt(prompt, promptProvider) @@ -398,7 +398,7 @@ public open class Server( name: String, description: String, mimeType: String = "text/html", - readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + readHandler: suspend Server.(ReadResourceRequest) -> ReadResourceResult ) { if (capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } @@ -604,7 +604,7 @@ public open class Server( throw IllegalArgumentException("Tool not found: ${request.name}") } logger.trace { "Executing tool ${request.name} with input: ${request.arguments}" } - return tool.handler(request) + return tool.handler.invoke(this, request) } private suspend fun handleListPrompts(): ListPromptsResult { @@ -619,7 +619,7 @@ public open class Server( logger.error { "Prompt not found: ${request.name}" } throw IllegalArgumentException("Prompt not found: ${request.name}") } - return prompt.messageProvider(request) + return prompt.messageProvider.invoke(this, request) } private suspend fun handleListResources(): ListResourcesResult { @@ -634,7 +634,7 @@ public open class Server( logger.error { "Resource not found: ${request.uri}" } throw IllegalArgumentException("Resource not found: ${request.uri}") } - return resource.readHandler(request) + return resource.readHandler.invoke(this, request) } private suspend fun handleListResourceTemplates(): ListResourceTemplatesResult { @@ -775,7 +775,7 @@ public open class Server( */ public data class RegisteredTool( val tool: Tool, - val handler: suspend (CallToolRequest) -> CallToolResult + val handler: suspend Server.(CallToolRequest) -> CallToolResult ) /** @@ -786,7 +786,7 @@ public data class RegisteredTool( */ public data class RegisteredPrompt( val prompt: Prompt, - val messageProvider: suspend (GetPromptRequest) -> GetPromptResult + val messageProvider: suspend Server.(GetPromptRequest) -> GetPromptResult ) /** @@ -797,5 +797,5 @@ public data class RegisteredPrompt( */ public data class RegisteredResource( val resource: Resource, - val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + val readHandler: suspend Server.(ReadResourceRequest) -> ReadResourceResult ) diff --git a/src/jvmTest/kotlin/server/ServerTest.kt b/src/jvmTest/kotlin/server/ServerTest.kt index 35e07741..cfdb5200 100644 --- a/src/jvmTest/kotlin/server/ServerTest.kt +++ b/src/jvmTest/kotlin/server/ServerTest.kt @@ -1,9 +1,12 @@ package server +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.Prompt import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification @@ -18,8 +21,11 @@ import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals @@ -468,4 +474,48 @@ class ServerTest { } assertEquals("Server does not support resources capability.", exception.message) } + + @Test + fun `sendLoggingMessage should throw when logging capability is disabled`() = runTest { + // Create server with tools capability + val serverOptions = ServerOptions( + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(null), + logging = null, + ) + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions + ) + + //Register a tool + server.addTool("test-tool", "Represent a tool example") { + // Verify that sending a logging message throws an exception + val exception = assertThrows { + sendLoggingMessage( + LoggingMessageNotification(level = LoggingLevel.alert, data = JsonObject(mapOf("progress" to JsonPrimitive(10)))) + ) + } + assertEquals("Server does not support logging (required for notifications/message)", exception.message) + + CallToolResult( + content = listOf(TextContent("Tool completed successfully.")) + ) + } + + // Setup client + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + ) + + // Connect client and server + listOf( + launch { server.connect(serverTransport) }, + launch { client.connect(clientTransport) }, + ).joinAll() + + client.request(CallToolRequest(name = "test-tool")) + } }