Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ tracing = { version = "0.1" }
tokio-util = { version = "0.7" }
pin-project-lite = "0.2"
paste = { version = "1", optional = true }

async-trait = "0.1"
# oauth2 support
oauth2 = { version = "5.0", optional = true }

Expand Down
4 changes: 4 additions & 0 deletions crates/rmcp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub enum RmcpError {
error: Box<dyn std::error::Error + Send + Sync>,
},
// and cancellation shouldn't be an error?

// TODO: add more error variants as needed
#[error("Task error: {0}")]
TaskError(String),
}

impl RmcpError {
Expand Down
46 changes: 46 additions & 0 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ mod resource;
pub mod router;
pub mod tool;
pub mod wrapper;

impl<H: ServerHandler> Service<RoleServer> for H {
async fn handle_request(
&self,
request: <RoleServer as ServiceRole>::PeerReq,
context: RequestContext<RoleServer>,
) -> Result<<RoleServer as ServiceRole>::Resp, McpError> {
// Pre-dispatch: check task meta and optionally enqueue as task
if context.meta.get_task().is_some() {
// Allow handler to decide whether and how to enqueue task
if let Some(result) = self.enqueue_task(&request, context.clone()).await? {
return Ok(result);
}
}

match request {
ClientRequest::InitializeRequest(request) => self
.initialize(request.params, context)
Expand Down Expand Up @@ -68,6 +77,14 @@ impl<H: ServerHandler> Service<RoleServer> for H {
.list_tools(request.params, context)
.await
.map(ServerResult::ListToolsResult),
ClientRequest::ListTasksRequest(request) => self
.list_tasks(request.params, context)
.await
.map(ServerResult::ListTasksResult),
ClientRequest::GetTaskInfoRequest(request) => self
.get_task_info(request.params, context)
.await
.map(ServerResult::GetTaskInfoResult),
}
}

Expand Down Expand Up @@ -100,6 +117,19 @@ impl<H: ServerHandler> Service<RoleServer> for H {

#[allow(unused_variables)]
pub trait ServerHandler: Sized + Send + Sync + 'static {
/// Optional pre-dispatch hook to enqueue incoming request as a background task.
/// Default: do nothing and return None.
/// Implementors that also act as an OperationHandler may override this to:
/// - Inspect `context.meta` (e.g., key "modelcontextprotocol.io/task")
/// - Build an operation and submit to a task manager
/// - Return an immediate ServerResult (e.g., EmptyResult or a Task ack)
fn enqueue_task(
&self,
_request: &ClientRequest,
_context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<Option<ServerResult>, McpError>> + Send + '_ {
std::future::ready(Ok(None))
}
fn ping(
&self,
context: RequestContext<RoleServer>,
Expand Down Expand Up @@ -228,4 +258,20 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
fn get_info(&self) -> ServerInfo {
ServerInfo::default()
}

fn list_tasks(
&self,
request: Option<PaginatedRequestParam>,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
}

fn get_task_info(
&self,
request: GetTaskInfoParam,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
}
}
1 change: 1 addition & 0 deletions crates/rmcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ pub use service::{RoleClient, serve_client};
pub use service::{RoleServer, serve_server};

pub mod handler;
pub mod task_manager;
pub mod transport;

// re-export
Expand Down
49 changes: 48 additions & 1 deletion crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod meta;
mod prompt;
mod resource;
mod serde_impl;
mod task;
mod tool;
pub use annotated::*;
pub use capabilities::*;
Expand All @@ -19,6 +20,7 @@ pub use prompt::*;
pub use resource::*;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Value;
pub use task::*;
pub use tool::*;

/// A JSON object type alias for convenient handling of JSON data.
Expand Down Expand Up @@ -1654,6 +1656,23 @@ pub struct GetPromptResult {
pub messages: Vec<PromptMessage>,
}

// =============================================================================
// TASK MANAGEMENT
// =============================================================================

const_string!(GetTaskInfoMethod = "tasks/get");
pub type GetTaskInfoRequest = Request<GetTaskInfoMethod, GetTaskInfoParam>;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskInfoParam {
pub task_id: String,
}

const_string!(ListTasksMethod = "tasks/list");
pub type ListTasksRequest = RequestOptionalParam<ListTasksMethod, PaginatedRequestParam>;

// =============================================================================
// MESSAGE TYPE UNIONS
// =============================================================================
Expand Down Expand Up @@ -1720,7 +1739,9 @@ ts_union!(
| SubscribeRequest
| UnsubscribeRequest
| CallToolRequest
| ListToolsRequest;
| ListToolsRequest
| GetTaskInfoRequest
| ListTasksRequest;
);

impl ClientRequest {
Expand All @@ -1739,6 +1760,8 @@ impl ClientRequest {
ClientRequest::UnsubscribeRequest(r) => r.method.as_str(),
ClientRequest::CallToolRequest(r) => r.method.as_str(),
ClientRequest::ListToolsRequest(r) => r.method.as_str(),
ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(),
ClientRequest::ListTasksRequest(r) => r.method.as_str(),
}
}
}
Expand Down Expand Up @@ -1794,6 +1817,8 @@ ts_union!(
| CallToolResult
| ListToolsResult
| CreateElicitationResult
| GetTaskInfoResult
| ListTasksResult
| EmptyResult
;
);
Expand All @@ -1804,6 +1829,28 @@ impl ServerResult {
}
}

// =============================================================================
// TASK RESULT TYPES (Server responses for task queries)
// =============================================================================
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskInfoResult {
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<crate::model::Task>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ListTasksResult {
pub tasks: Vec<crate::model::Task>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
}

pub type ServerJsonRpcMessage = JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>;

impl TryInto<CancelledNotification> for ServerNotification {
Expand Down
9 changes: 9 additions & 0 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ variant_extension! {
UnsubscribeRequest
CallToolRequest
ListToolsRequest
GetTaskInfoRequest
ListTasksRequest
}
}

Expand Down Expand Up @@ -103,6 +105,7 @@ variant_extension! {
#[serde(transparent)]
pub struct Meta(pub JsonObject);
const PROGRESS_TOKEN_FIELD: &str = "progressToken";
const TASK_FIELD: &str = "modelcontextprotocol.io/task";
impl Meta {
pub fn new() -> Self {
Self(JsonObject::new())
Expand Down Expand Up @@ -133,6 +136,12 @@ impl Meta {
})
}

pub fn get_task(&self) -> Option<String> {
self.0
.get(TASK_FIELD)
.and_then(|v| v.as_str().map(|s| s.to_string()))
}

pub fn set_progress_token(&mut self, token: ProgressToken) {
match token.0 {
NumberOrString::String(ref s) => self.0.insert(
Expand Down
Loading
Loading