Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 55 additions & 127 deletions mcp-server/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,44 @@ pub struct GenericServerHandler<B: McpBackend> {
middleware: MiddlewareStack,
/// Global subscription registry tracking subscribed resource URIs
/// Note: This is a simplified global implementation. For per-client
/// subscriptions, use a HashMap<ClientId, HashSet<String>> instead.
/// subscriptions, use a `HashMap<ClientId, HashSet<String>>` instead.
subscriptions: Arc<RwLock<HashSet<String>>>,
}

/// Helper to create a JSON-RPC response with a result
#[inline]
fn make_response(id: Option<NumberOrString>, result: serde_json::Value) -> Response {
Response {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}

/// Helper to create an empty JSON-RPC response (for void methods)
#[inline]
fn make_empty_response(id: Option<NumberOrString>) -> Response {
Response {
jsonrpc: "2.0".to_string(),
id,
result: Some(serde_json::Value::Object(Default::default())),
error: None,
}
}

/// Parse optional paginated params, defaulting to no cursor
#[inline]
fn parse_paginated_params(
params: serde_json::Value,
) -> std::result::Result<PaginatedRequestParam, Error> {
if params.is_null() {
Ok(PaginatedRequestParam { cursor: None })
} else {
serde_json::from_value(params).map_err(Error::from)
}
}

impl<B: McpBackend> GenericServerHandler<B> {
/// Create a new handler
pub fn new(
Expand Down Expand Up @@ -230,24 +264,13 @@ impl<B: McpBackend> GenericServerHandler<B> {

#[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
let params: PaginatedRequestParam = if request.params.is_null() {
PaginatedRequestParam { cursor: None }
} else {
serde_json::from_value(request.params)?
};

let params = parse_paginated_params(request.params)?;
let result = self
.backend
.list_tools(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

#[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
Expand Down Expand Up @@ -302,103 +325,56 @@ impl<B: McpBackend> GenericServerHandler<B> {
&self,
request: Request,
) -> std::result::Result<Response, Error> {
let params: PaginatedRequestParam = if request.params.is_null() {
PaginatedRequestParam { cursor: None }
} else {
serde_json::from_value(request.params)?
};

let params = parse_paginated_params(request.params)?;
let result = self
.backend
.list_resources(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;

let result = self
.backend
.read_resource(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_list_resource_templates(
&self,
request: Request,
) -> std::result::Result<Response, Error> {
let params: PaginatedRequestParam = if request.params.is_null() {
PaginatedRequestParam { cursor: None }
} else {
serde_json::from_value(request.params)?
};

let params = parse_paginated_params(request.params)?;
let result = self
.backend
.list_resource_templates(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
let params: PaginatedRequestParam = if request.params.is_null() {
PaginatedRequestParam { cursor: None }
} else {
serde_json::from_value(request.params)?
};

let params = parse_paginated_params(request.params)?;
let result = self
.backend
.list_prompts(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
let params: GetPromptRequestParam = serde_json::from_value(request.params)?;

let result = self
.backend
.get_prompt(params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
Expand All @@ -410,19 +386,11 @@ impl<B: McpBackend> GenericServerHandler<B> {

// Track subscription globally
let mut subs = self.subscriptions.write().await;
let is_new = subs.insert(uri.clone());
drop(subs);

if is_new {
if subs.insert(uri.clone()) {
debug!("Subscribed to resource: {}", uri);
}

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::Value::Object(Default::default())),
error: None,
})
Ok(make_empty_response(request.id))
}

async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
Expand All @@ -437,67 +405,33 @@ impl<B: McpBackend> GenericServerHandler<B> {

// Remove from subscription tracking
let mut subs = self.subscriptions.write().await;
let was_present = subs.remove(&uri);
drop(subs);

if was_present {
if subs.remove(&uri) {
debug!("Unsubscribed from resource: {}", uri);
}

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::Value::Object(Default::default())),
error: None,
})
Ok(make_empty_response(request.id))
}

async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
let params: CompleteRequestParam = serde_json::from_value(request.params)?;

let result = self.backend.complete(params).await.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
let params: ElicitationRequestParam = serde_json::from_value(request.params)?;

let result = self.backend.elicit(params).await.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::to_value(result)?),
error: None,
})
Ok(make_response(request.id, serde_json::to_value(result)?))
}

async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
let params: SetLevelRequestParam = serde_json::from_value(request.params)?;

self.backend.set_level(params).await.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::Value::Object(Default::default())),
error: None,
})
Ok(make_empty_response(request.id))
}

async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
Ok(Response {
jsonrpc: "2.0".to_string(),
id: _request.id,
result: Some(serde_json::Value::Object(Default::default())),
error: None,
})
async fn handle_ping(&self, request: Request) -> std::result::Result<Response, Error> {
Ok(make_empty_response(request.id))
}

async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
Expand All @@ -506,13 +440,7 @@ impl<B: McpBackend> GenericServerHandler<B> {
.handle_custom_method(&request.method, request.params)
.await
.map_err(|e| e.into())?;

Ok(Response {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(result),
error: None,
})
Ok(make_response(request.id, result))
}
}

Expand Down