diff --git a/mcp-server/src/handler.rs b/mcp-server/src/handler.rs index 1ee8d7d..65eb61f 100644 --- a/mcp-server/src/handler.rs +++ b/mcp-server/src/handler.rs @@ -71,10 +71,44 @@ pub struct GenericServerHandler { middleware: MiddlewareStack, /// Global subscription registry tracking subscribed resource URIs /// Note: This is a simplified global implementation. For per-client - /// subscriptions, use a HashMap> instead. + /// subscriptions, use a `HashMap>` instead. subscriptions: Arc>>, } +/// Helper to create a JSON-RPC response with a result +#[inline] +fn make_response(id: Option, result: serde_json::Value) -> Response { + Response { + jsonrpc: "2.0".to_string(), + id, + result: Some(result), + error: None, + } +} + +/// Helper to create an empty JSON-RPC response (for void methods) +#[inline] +fn make_empty_response(id: Option) -> Response { + Response { + jsonrpc: "2.0".to_string(), + id, + result: Some(serde_json::Value::Object(Default::default())), + error: None, + } +} + +/// Parse optional paginated params, defaulting to no cursor +#[inline] +fn parse_paginated_params( + params: serde_json::Value, +) -> std::result::Result { + if params.is_null() { + Ok(PaginatedRequestParam { cursor: None }) + } else { + serde_json::from_value(params).map_err(Error::from) + } +} + impl GenericServerHandler { /// Create a new handler pub fn new( @@ -230,24 +264,13 @@ impl GenericServerHandler { #[instrument(skip(self, request), fields(mcp.method = "tools/list"))] async fn handle_list_tools(&self, request: Request) -> std::result::Result { - let params: PaginatedRequestParam = if request.params.is_null() { - PaginatedRequestParam { cursor: None } - } else { - serde_json::from_value(request.params)? - }; - + let params = parse_paginated_params(request.params)?; let result = self .backend .list_tools(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } #[instrument(skip(self, request), fields(mcp.method = "tools/call"))] @@ -302,103 +325,56 @@ impl GenericServerHandler { &self, request: Request, ) -> std::result::Result { - let params: PaginatedRequestParam = if request.params.is_null() { - PaginatedRequestParam { cursor: None } - } else { - serde_json::from_value(request.params)? - }; - + let params = parse_paginated_params(request.params)?; let result = self .backend .list_resources(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_read_resource(&self, request: Request) -> std::result::Result { let params: ReadResourceRequestParam = serde_json::from_value(request.params)?; - let result = self .backend .read_resource(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_list_resource_templates( &self, request: Request, ) -> std::result::Result { - let params: PaginatedRequestParam = if request.params.is_null() { - PaginatedRequestParam { cursor: None } - } else { - serde_json::from_value(request.params)? - }; - + let params = parse_paginated_params(request.params)?; let result = self .backend .list_resource_templates(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_list_prompts(&self, request: Request) -> std::result::Result { - let params: PaginatedRequestParam = if request.params.is_null() { - PaginatedRequestParam { cursor: None } - } else { - serde_json::from_value(request.params)? - }; - + let params = parse_paginated_params(request.params)?; let result = self .backend .list_prompts(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_get_prompt(&self, request: Request) -> std::result::Result { let params: GetPromptRequestParam = serde_json::from_value(request.params)?; - let result = self .backend .get_prompt(params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_subscribe(&self, request: Request) -> std::result::Result { @@ -410,19 +386,11 @@ impl GenericServerHandler { // Track subscription globally let mut subs = self.subscriptions.write().await; - let is_new = subs.insert(uri.clone()); - drop(subs); - - if is_new { + if subs.insert(uri.clone()) { debug!("Subscribed to resource: {}", uri); } - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::Value::Object(Default::default())), - error: None, - }) + Ok(make_empty_response(request.id)) } async fn handle_unsubscribe(&self, request: Request) -> std::result::Result { @@ -437,67 +405,33 @@ impl GenericServerHandler { // Remove from subscription tracking let mut subs = self.subscriptions.write().await; - let was_present = subs.remove(&uri); - drop(subs); - - if was_present { + if subs.remove(&uri) { debug!("Unsubscribed from resource: {}", uri); } - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::Value::Object(Default::default())), - error: None, - }) + Ok(make_empty_response(request.id)) } async fn handle_complete(&self, request: Request) -> std::result::Result { let params: CompleteRequestParam = serde_json::from_value(request.params)?; - let result = self.backend.complete(params).await.map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_elicit(&self, request: Request) -> std::result::Result { let params: ElicitationRequestParam = serde_json::from_value(request.params)?; - let result = self.backend.elicit(params).await.map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::to_value(result)?), - error: None, - }) + Ok(make_response(request.id, serde_json::to_value(result)?)) } async fn handle_set_level(&self, request: Request) -> std::result::Result { let params: SetLevelRequestParam = serde_json::from_value(request.params)?; - self.backend.set_level(params).await.map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(serde_json::Value::Object(Default::default())), - error: None, - }) + Ok(make_empty_response(request.id)) } - async fn handle_ping(&self, _request: Request) -> std::result::Result { - Ok(Response { - jsonrpc: "2.0".to_string(), - id: _request.id, - result: Some(serde_json::Value::Object(Default::default())), - error: None, - }) + async fn handle_ping(&self, request: Request) -> std::result::Result { + Ok(make_empty_response(request.id)) } async fn handle_custom_method(&self, request: Request) -> std::result::Result { @@ -506,13 +440,7 @@ impl GenericServerHandler { .handle_custom_method(&request.method, request.params) .await .map_err(|e| e.into())?; - - Ok(Response { - jsonrpc: "2.0".to_string(), - id: request.id, - result: Some(result), - error: None, - }) + Ok(make_response(request.id, result)) } }