Skip to content

Add RegisteredTool receiver #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public open class Server(
description: String,
inputSchema: Tool.Input = Tool.Input(),
toolAnnotations: ToolAnnotations? = null,
handler: suspend (CallToolRequest) -> CallToolResult
handler: suspend Server.(CallToolRequest) -> CallToolResult
) {
if (capabilities.tools == null) {
logger.error { "Failed to add tool '$name': Server does not support tools capability" }
Expand Down Expand Up @@ -285,7 +285,7 @@ public open class Server(
* @param promptProvider A suspend function that returns the prompt content when requested by the client.
* @throws IllegalStateException If the server does not support prompts.
*/
public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) {
public fun addPrompt(prompt: Prompt, promptProvider: suspend Server.(GetPromptRequest) -> GetPromptResult) {
if (capabilities.prompts == null) {
logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" }
throw IllegalStateException("Server does not support prompts capability.")
Expand All @@ -307,7 +307,7 @@ public open class Server(
name: String,
description: String? = null,
arguments: List<PromptArgument>? = null,
promptProvider: suspend (GetPromptRequest) -> GetPromptResult
promptProvider: suspend Server.(GetPromptRequest) -> GetPromptResult
) {
val prompt = Prompt(name = name, description = description, arguments = arguments)
addPrompt(prompt, promptProvider)
Expand Down Expand Up @@ -398,7 +398,7 @@ public open class Server(
name: String,
description: String,
mimeType: String = "text/html",
readHandler: suspend (ReadResourceRequest) -> ReadResourceResult
readHandler: suspend Server.(ReadResourceRequest) -> ReadResourceResult
) {
if (capabilities.resources == null) {
logger.error { "Failed to add resource '$name': Server does not support resources capability" }
Expand Down Expand Up @@ -604,7 +604,7 @@ public open class Server(
throw IllegalArgumentException("Tool not found: ${request.name}")
}
logger.trace { "Executing tool ${request.name} with input: ${request.arguments}" }
return tool.handler(request)
return tool.handler.invoke(this, request)
}

private suspend fun handleListPrompts(): ListPromptsResult {
Expand All @@ -619,7 +619,7 @@ public open class Server(
logger.error { "Prompt not found: ${request.name}" }
throw IllegalArgumentException("Prompt not found: ${request.name}")
}
return prompt.messageProvider(request)
return prompt.messageProvider.invoke(this, request)
}

private suspend fun handleListResources(): ListResourcesResult {
Expand All @@ -634,7 +634,7 @@ public open class Server(
logger.error { "Resource not found: ${request.uri}" }
throw IllegalArgumentException("Resource not found: ${request.uri}")
}
return resource.readHandler(request)
return resource.readHandler.invoke(this, request)
}

private suspend fun handleListResourceTemplates(): ListResourceTemplatesResult {
Expand Down Expand Up @@ -775,7 +775,7 @@ public open class Server(
*/
public data class RegisteredTool(
val tool: Tool,
val handler: suspend (CallToolRequest) -> CallToolResult
val handler: suspend Server.(CallToolRequest) -> CallToolResult
)

/**
Expand All @@ -786,7 +786,7 @@ public data class RegisteredTool(
*/
public data class RegisteredPrompt(
val prompt: Prompt,
val messageProvider: suspend (GetPromptRequest) -> GetPromptResult
val messageProvider: suspend Server.(GetPromptRequest) -> GetPromptResult
)

/**
Expand All @@ -797,5 +797,5 @@ public data class RegisteredPrompt(
*/
public data class RegisteredResource(
val resource: Resource,
val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult
val readHandler: suspend Server.(ReadResourceRequest) -> ReadResourceResult
)
50 changes: 50 additions & 0 deletions src/jvmTest/kotlin/server/ServerTest.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package server

import io.modelcontextprotocol.kotlin.sdk.CallToolRequest
import io.modelcontextprotocol.kotlin.sdk.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.GetPromptResult
import io.modelcontextprotocol.kotlin.sdk.Implementation
import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.Method
import io.modelcontextprotocol.kotlin.sdk.Prompt
import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification
Expand All @@ -18,8 +21,11 @@ import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import kotlin.test.assertEquals
Expand Down Expand Up @@ -468,4 +474,48 @@ class ServerTest {
}
assertEquals("Server does not support resources capability.", exception.message)
}

@Test
fun `sendLoggingMessage should throw when logging capability is disabled`() = runTest {
// Create server with tools capability
val serverOptions = ServerOptions(
capabilities = ServerCapabilities(
tools = ServerCapabilities.Tools(null),
logging = null,
)
)
val server = Server(
Implementation(name = "test server", version = "1.0"),
serverOptions
)

//Register a tool
server.addTool("test-tool", "Represent a tool example") {
// Verify that sending a logging message throws an exception
val exception = assertThrows<IllegalStateException> {
sendLoggingMessage(
LoggingMessageNotification(level = LoggingLevel.alert, data = JsonObject(mapOf("progress" to JsonPrimitive(10))))
)
}
assertEquals("Server does not support logging (required for notifications/message)", exception.message)

CallToolResult(
content = listOf(TextContent("Tool completed successfully."))
)
}

// Setup client
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
val client = Client(
clientInfo = Implementation(name = "test client", version = "1.0"),
)

// Connect client and server
listOf(
launch { server.connect(serverTransport) },
launch { client.connect(clientTransport) },
).joinAll()

client.request<CallToolResult>(CallToolRequest(name = "test-tool"))
}
}