From f5bb9b0e2ffbaf26c36bce04ffda5f3d52209ca3 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 14:41:33 +0200 Subject: [PATCH 01/12] feat: add actix-web support for SSE server transport - Add actix-web as an optional web framework alongside Axum - Refactor SSE server into modular structure with common types - Create actix_web implementation with identical API to Axum - Add comprehensive unit and integration tests for both implementations - Make actix-web the default when enabled (Axum still available as AxumSseServer) - Add working actix-web example (counter_sse_actix.rs) - Fix calculator test model to use #[tool_handler] macro - Verified 100% protocol compatibility with JavaScript and Python MCP clients The implementation maintains full backwards compatibility - when only the axum feature is enabled, the original Axum implementation is used. --- crates/rmcp/Cargo.toml | 9 +- crates/rmcp/src/transport.rs | 4 +- .../src/transport/sse_server/actix_impl.rs | 603 ++++++++++++++++++ .../axum_impl.rs} | 196 +++++- .../rmcp/src/transport/sse_server/common.rs | 19 + crates/rmcp/src/transport/sse_server/mod.rs | 27 + crates/rmcp/tests/common/calculator.rs | 3 +- crates/rmcp/tests/test_sse_server.rs | 250 ++++++++ examples/servers/Cargo.toml | 10 + examples/servers/src/counter_sse_actix.rs | 52 ++ 10 files changed, 1157 insertions(+), 16 deletions(-) create mode 100644 crates/rmcp/src/transport/sse_server/actix_impl.rs rename crates/rmcp/src/transport/{sse_server.rs => sse_server/axum_impl.rs} (61%) create mode 100644 crates/rmcp/src/transport/sse_server/common.rs create mode 100644 crates/rmcp/src/transport/sse_server/mod.rs create mode 100644 crates/rmcp/tests/test_sse_server.rs create mode 100644 examples/servers/src/counter_sse_actix.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 2b9126d1..e7dd3365 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -55,12 +55,15 @@ process-wrap = { version = "8.2", features = ["tokio1"], optional = true } # for http-server transport axum = { version = "0.8", features = [], optional = true } +actix-web = { version = "4", optional = true } +actix-rt = { version = "2", optional = true } rand = { version = "0.9", optional = true } tokio-stream = { version = "0.1", optional = true } uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } +async-stream = { version = "0.3", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] @@ -70,10 +73,13 @@ chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", default-features = false, features = ["serde", "clock", "std", "oldtime"] } [features] -default = ["base64", "macros", "server"] +default = ["base64", "macros", "server", "axum"] client = ["dep:tokio-stream"] server = ["transport-async-rw", "dep:schemars"] macros = ["dep:rmcp-macros", "dep:paste"] +# Web framework features +axum = ["dep:axum"] +actix-web = ["dep:actix-web", "dep:actix-rt", "dep:async-stream"] # reqwest http client __reqwest = ["dep:reqwest"] @@ -116,7 +122,6 @@ transport-sse-server = [ "transport-async-rw", "transport-worker", "server-side-http", - "dep:axum", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 20e6ce75..ab02abec 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -105,8 +105,8 @@ pub use sse_client::SseClientTransport; #[cfg(feature = "transport-sse-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] pub mod sse_server; -#[cfg(feature = "transport-sse-server")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] +#[cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))))] pub use sse_server::SseServer; #[cfg(feature = "auth")] diff --git a/crates/rmcp/src/transport/sse_server/actix_impl.rs b/crates/rmcp/src/transport/sse_server/actix_impl.rs new file mode 100644 index 00000000..e3fc2722 --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/actix_impl.rs @@ -0,0 +1,603 @@ +use std::{collections::HashMap, io, sync::Arc, time::Duration}; + +use actix_web::{ + HttpRequest, HttpResponse, Result, Scope, + error::ErrorInternalServerError, + web::{self, Bytes, Data, Json, Query}, +}; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::{CancellationToken, PollSender}; +use tracing::Instrument; + +use crate::{ + RoleServer, Service, + model::ClientJsonRpcMessage, + service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, +}; + +use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; + +type TxStore = + Arc>>>; +pub type TransportReceiver = ReceiverStream>; + +#[derive(Clone, Debug)] +struct AppData { + txs: TxStore, + transport_tx: tokio::sync::mpsc::UnboundedSender, + post_path: Arc, + sse_ping_interval: Duration, +} + +impl AppData { + pub fn new( + post_path: String, + sse_ping_interval: Duration, + ) -> ( + Self, + tokio::sync::mpsc::UnboundedReceiver, + ) { + let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); + ( + Self { + txs: Default::default(), + transport_tx, + post_path: post_path.into(), + sse_ping_interval, + }, + transport_rx, + ) + } +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + app_data: Data, + query: Query, + _req: HttpRequest, + message: Json, +) -> Result { + let session_id = &query.session_id; + tracing::debug!(session_id, ?message, "new client message"); + + let tx = { + let rg = app_data.txs.read().await; + rg.get(session_id.as_str()) + .ok_or_else(|| actix_web::error::ErrorNotFound("Session not found"))? + .clone() + }; + + // Note: In actix-web, we don't have direct access to modify extensions + // This would need a different approach for passing HTTP request context + + if tx.send(message.0).await.is_err() { + tracing::error!("send message error"); + return Err(actix_web::error::ErrorGone("Session closed")); + } + + Ok(HttpResponse::Accepted().finish()) +} + +async fn sse_handler( + app_data: Data, + _req: HttpRequest, +) -> Result { + let session = session_id(); + tracing::info!(%session, "sse connection"); + + let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); + let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64); + let to_client_tx_clone = to_client_tx.clone(); + + app_data.txs + .write() + .await + .insert(session.clone(), from_client_tx); + + let _session_id = session.clone(); + let stream = ReceiverStream::new(from_client_rx); + let sink = PollSender::new(to_client_tx); + let transport = SseServerTransport { + stream, + sink, + session_id: session.clone(), + tx_store: app_data.txs.clone(), + }; + + let transport_send_result = app_data.transport_tx.send(transport); + if transport_send_result.is_err() { + tracing::warn!("send transport out error"); + return Err(ErrorInternalServerError("Failed to send transport, server is closed")); + } + + let post_path = app_data.post_path.clone(); + let ping_interval = app_data.sse_ping_interval; + let session_for_stream = session.clone(); + + // Create SSE response stream + let sse_stream = async_stream::stream! { + // Send initial endpoint message + yield Ok::<_, actix_web::Error>(Bytes::from(format!( + "event: endpoint\ndata: {}?sessionId={}\n\n", + post_path, session_for_stream + ))); + + // Set up ping interval + let mut ping_interval = tokio::time::interval(ping_interval); + ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let mut rx = ReceiverStream::new(to_client_rx); + + loop { + tokio::select! { + Some(message) = rx.next() => { + match serde_json::to_string(&message) { + Ok(json) => { + yield Ok(Bytes::from(format!("event: message\ndata: {}\n\n", json))); + } + Err(e) => { + tracing::error!("Failed to serialize message: {}", e); + } + } + } + _ = ping_interval.tick() => { + yield Ok(Bytes::from(": ping\n\n")); + } + else => break, + } + } + }; + + // Clean up on disconnect + let app_data_clone = app_data.clone(); + let session_for_cleanup = session.clone(); + actix_rt::spawn(async move { + to_client_tx_clone.closed().await; + + let mut txs = app_data_clone.txs.write().await; + txs.remove(&session_for_cleanup); + tracing::debug!(%session_for_cleanup, "Closed session and cleaned up resources"); + }); + + Ok(HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(sse_stream)) +} + +pub struct SseServerTransport { + stream: ReceiverStream>, + sink: PollSender>, + session_id: SessionId, + tx_store: TxStore, +} + +impl Sink> for SseServerTransport { + type Error = io::Error; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_ready_unpin(cx) + .map_err(std::io::Error::other) + } + + fn start_send( + mut self: std::pin::Pin<&mut Self>, + item: TxJsonRpcMessage, + ) -> Result<(), Self::Error> { + self.sink + .start_send_unpin(item) + .map_err(std::io::Error::other) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_flush_unpin(cx) + .map_err(std::io::Error::other) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let inner_close_result = self + .sink + .poll_close_unpin(cx) + .map_err(std::io::Error::other); + if inner_close_result.is_ready() { + let session_id = self.session_id.clone(); + let tx_store = self.tx_store.clone(); + tokio::spawn(async move { + tx_store.write().await.remove(&session_id); + }); + } + inner_close_result + } +} + +impl Stream for SseServerTransport { + type Item = RxJsonRpcMessage; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } +} + +#[derive(Debug)] +pub struct SseServer { + transport_rx: Arc>>, + pub config: SseServerConfig, + app_data: Data, +} + +impl SseServer { + pub async fn serve(bind: std::net::SocketAddr) -> io::Result { + Self::serve_with_config(SseServerConfig { + bind, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }) + .await + } + + pub async fn serve_with_config(mut config: SseServerConfig) -> io::Result { + let bind_addr = config.bind; + let ct = config.ct.clone(); + + // First bind to get the actual address + let listener = std::net::TcpListener::bind(bind_addr)?; + let actual_addr = listener.local_addr()?; + listener.set_nonblocking(true)?; + + // Update config with actual address + config.bind = actual_addr; + let (sse_server, _) = Self::new(config); + let app_data = sse_server.app_data.clone(); + let sse_path = sse_server.config.sse_path.clone(); + let post_path = sse_server.config.post_path.clone(); + + let server = actix_web::HttpServer::new(move || { + actix_web::App::new() + .app_data(app_data.clone()) + .route(&sse_path, web::get().to(sse_handler)) + .route(&post_path, web::post().to(post_event_handler)) + }) + .listen(listener)? + .run(); + + let ct_child = ct.child_token(); + let server_handle = server.handle(); + + actix_rt::spawn(async move { + ct_child.cancelled().await; + tracing::info!("sse server cancelled"); + server_handle.stop(true).await; + }); + + actix_rt::spawn( + async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "sse server shutdown with error"); + } + } + .instrument(tracing::info_span!("sse-server", bind_address = %actual_addr)), + ); + + Ok(sse_server) + } + + pub fn new(config: SseServerConfig) -> (SseServer, Scope) { + let (app_data, transport_rx) = AppData::new( + config.post_path.clone(), + config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), + ); + + let sse_path = config.sse_path.clone(); + let post_path = config.post_path.clone(); + + let app_data = Data::new(app_data); + + let scope = web::scope("") + .app_data(app_data.clone()) + .route(&sse_path, web::get().to(sse_handler)) + .route(&post_path, web::post().to(post_event_handler)); + + let server = SseServer { + transport_rx: Arc::new(Mutex::new(transport_rx)), + config, + app_data, + }; + + (server, scope) + } + + pub fn with_service(self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + use crate::service::ServiceExt; + let ct = self.config.ct.clone(); + let transport_rx = self.transport_rx.clone(); + + actix_rt::spawn(async move { + while let Some(transport) = transport_rx.lock().await.recv().await { + let service = service_provider(); + let ct_child = ct.child_token(); + tokio::spawn(async move { + let server = service + .serve_with_ct(transport, ct_child) + .await + .map_err(std::io::Error::other)?; + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + self.config.ct.clone() + } + + /// This allows you to skip the initialization steps for incoming request. + pub fn with_service_directly(self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + let ct = self.config.ct.clone(); + let transport_rx = self.transport_rx.clone(); + + actix_rt::spawn(async move { + while let Some(transport) = transport_rx.lock().await.recv().await { + let service = service_provider(); + let ct_child = ct.child_token(); + tokio::spawn(async move { + let server = serve_directly_with_ct(service, transport, None, ct_child); + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + self.config.ct.clone() + } + + pub fn cancel(&self) { + self.config.ct.cancel(); + } + + pub async fn next_transport(&self) -> Option { + self.transport_rx.lock().await.recv().await + } +} + +impl Stream for SseServer { + type Item = SseServerTransport; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let mut rx = match self.transport_rx.try_lock() { + Ok(rx) => rx, + Err(_) => { + cx.waker().wake_by_ref(); + return std::task::Poll::Pending; + } + }; + rx.poll_recv(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::{SinkExt, StreamExt}; + use tokio::time::timeout; + + #[tokio::test] + async fn test_session_management() { + let (app_data, transport_rx) = AppData::new("/message".to_string(), Duration::from_secs(15)); + + // Create a session + let session_id = session_id(); + let (tx, _rx) = tokio::sync::mpsc::channel(64); + + // Insert session + app_data.txs.write().await.insert(session_id.clone(), tx); + + // Verify session exists + assert!(app_data.txs.read().await.contains_key(&session_id)); + + // Remove session + app_data.txs.write().await.remove(&session_id); + + // Verify session removed + assert!(!app_data.txs.read().await.contains_key(&session_id)); + + drop(transport_rx); + } + + #[actix_web::test] + async fn test_sse_server_creation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + let (sse_server, scope) = SseServer::new(config); + + assert_eq!(sse_server.config.sse_path, "/sse"); + assert_eq!(sse_server.config.post_path, "/message"); + + // Scope should be properly configured + drop(scope); // Just ensure it's created without panic + } + + #[tokio::test] + async fn test_transport_stream() { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let stream = ReceiverStream::new(rx); + let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); + let sink = PollSender::new(sink_tx); + + let mut transport = SseServerTransport { + stream, + sink, + session_id: session_id(), + tx_store: Default::default(), + }; + + // Test sending through transport + use crate::model::{ServerResult, EmptyResult, JsonRpcMessage}; + let msg: TxJsonRpcMessage = JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); + // For PollSender, we need to send through async context + transport.send(msg).await.unwrap(); + + // Should receive the message + let received = timeout(Duration::from_millis(100), sink_rx.recv()) + .await + .unwrap() + .unwrap(); + + match received { + TxJsonRpcMessage::::Response(_) => {}, + _ => panic!("Unexpected message type"), + } + + // Test receiving through transport + let client_msg: RxJsonRpcMessage = crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + } + ), + }); + tx.send(client_msg).await.unwrap(); + drop(tx); + + let received = timeout(Duration::from_millis(100), transport.next()) + .await + .unwrap() + .unwrap(); + + match received { + RxJsonRpcMessage::::Notification(_) => {}, + _ => panic!("Unexpected message type"), + } + } + + #[actix_web::test] + async fn test_post_event_handler_session_not_found() { + use actix_web::test; + + let (app_data, _) = AppData::new("/message".to_string(), Duration::from_secs(15)); + let app_data = Data::new(app_data); + + let query = PostEventQuery { + session_id: "non-existent".to_string(), + }; + + // Create a simple cancelled notification + let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + } + ), + }); + + let result = post_event_handler( + app_data, + Query(query), + test::TestRequest::default().to_http_request(), + Json(client_msg), + ).await; + + assert!(result.is_err()); + } + + #[actix_web::test] + async fn test_server_with_cancellation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let (sse_server, _) = SseServer::new(config); + + // Test that the cancellation token is properly connected + assert!(!ct.is_cancelled()); + ct.cancel(); + assert!(ct.is_cancelled()); + + // Verify server config + assert!(sse_server.config.ct.is_cancelled()); + } + + #[actix_web::test] + async fn test_sse_stream_generation() { + let (app_data, mut transport_rx) = AppData::new("/message".to_string(), Duration::from_secs(15)); + let app_data = Data::new(app_data); + + // Call SSE handler + let result = sse_handler( + app_data.clone(), + actix_web::test::TestRequest::default().to_http_request(), + ).await; + + assert!(result.is_ok()); + let response = result.unwrap(); + + // Check response headers + assert_eq!(response.status(), actix_web::http::StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + // Verify a transport was created + let transport = transport_rx.try_recv(); + assert!(transport.is_ok()); + } +} \ No newline at end of file diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server/axum_impl.rs similarity index 61% rename from crates/rmcp/src/transport/sse_server.rs rename to crates/rmcp/src/transport/sse_server/axum_impl.rs index 15a65cb5..d3c9e9e3 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server/axum_impl.rs @@ -19,9 +19,10 @@ use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, - transport::common::server_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; +use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; + type TxStore = Arc>>>; pub type TransportReceiver = ReceiverStream>; @@ -214,14 +215,6 @@ impl Stream for SseServerTransport { } } -#[derive(Debug, Clone)] -pub struct SseServerConfig { - pub bind: SocketAddr, - pub sse_path: String, - pub post_path: String, - pub ct: CancellationToken, - pub sse_keep_alive: Option, -} #[derive(Debug)] pub struct SseServer { @@ -240,9 +233,11 @@ impl SseServer { }) .await } - pub async fn serve_with_config(config: SseServerConfig) -> io::Result { + pub async fn serve_with_config(mut config: SseServerConfig) -> io::Result { + let listener = tokio::net::TcpListener::bind(config.bind).await?; + // Update config with actual bound address (important when port is 0) + config.bind = listener.local_addr()?; let (sse_server, service) = Self::new(config); - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; let ct = sse_server.config.ct.child_token(); let server = axum::serve(listener, service).with_graceful_shutdown(async move { ct.cancelled().await; @@ -341,3 +336,182 @@ impl Stream for SseServer { self.transport_rx.poll_recv(cx) } } + +#[cfg(test)] +mod tests { + use super::*; + use futures::{SinkExt, StreamExt}; + use tokio::time::timeout; + + #[tokio::test] + async fn test_session_management() { + let (app, transport_rx) = App::new("/message".to_string(), Duration::from_secs(15)); + + // Create a session + let session_id = session_id(); + let (tx, _rx) = tokio::sync::mpsc::channel(64); + + // Insert session + app.txs.write().await.insert(session_id.clone(), tx); + + // Verify session exists + assert!(app.txs.read().await.contains_key(&session_id)); + + // Remove session + app.txs.write().await.remove(&session_id); + + // Verify session removed + assert!(!app.txs.read().await.contains_key(&session_id)); + + drop(transport_rx); + } + + #[tokio::test] + async fn test_sse_server_creation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + let (sse_server, router) = SseServer::new(config); + + assert_eq!(sse_server.config.sse_path, "/sse"); + assert_eq!(sse_server.config.post_path, "/message"); + + // Router should be properly configured + drop(router); // Just ensure it's created without panic + } + + #[tokio::test] + async fn test_transport_stream() { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let stream = ReceiverStream::new(rx); + let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); + let sink = PollSender::new(sink_tx); + + let mut transport = SseServerTransport { + stream, + sink, + session_id: session_id(), + tx_store: Default::default(), + }; + + // Test sending through transport + use crate::model::{ServerResult, EmptyResult, JsonRpcMessage}; + let msg: TxJsonRpcMessage = JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); + // For PollSender, we need to send through async context + transport.send(msg).await.unwrap(); + + // Should receive the message + let received = timeout(Duration::from_millis(100), sink_rx.recv()) + .await + .unwrap() + .unwrap(); + + match received { + TxJsonRpcMessage::::Response(_) => {}, + _ => panic!("Unexpected message type"), + } + + // Test receiving through transport + let client_msg: RxJsonRpcMessage = crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + } + ), + }); + tx.send(client_msg).await.unwrap(); + drop(tx); + + let received = timeout(Duration::from_millis(100), transport.next()) + .await + .unwrap() + .unwrap(); + + match received { + RxJsonRpcMessage::::Notification(_) => {}, + _ => panic!("Unexpected message type"), + } + } + + #[tokio::test] + async fn test_post_event_handler_session_not_found() { + use axum::extract::{Query, State}; + use axum::Json; + use axum::http::Request; + + let (app, _) = App::new("/message".to_string(), Duration::from_secs(15)); + + let query = PostEventQuery { + session_id: "non-existent".to_string(), + }; + + // Create a minimal request parts + let request = Request::builder() + .method("POST") + .uri("/message") + .body(()) + .unwrap(); + let (parts, _) = request.into_parts(); + + // Create a simple cancelled notification + let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + } + ), + }); + + let result = post_event_handler( + State(app), + Query(query), + parts, + Json(client_msg), + ).await; + + assert_eq!(result, Err(StatusCode::NOT_FOUND)); + } + + #[tokio::test] + async fn test_server_with_cancellation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct_clone = config.ct.clone(); + let (mut sse_server, _) = SseServer::new(config); + + // Cancel immediately + ct_clone.cancel(); + + // next_transport should return None after cancellation + let transport = timeout(Duration::from_millis(100), sse_server.next_transport()).await; + assert!(transport.is_ok()); + assert!(transport.unwrap().is_none()); + } +} diff --git a/crates/rmcp/src/transport/sse_server/common.rs b/crates/rmcp/src/transport/sse_server/common.rs new file mode 100644 index 00000000..6cfc4f3e --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/common.rs @@ -0,0 +1,19 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tokio_util::sync::CancellationToken; + +pub type SessionId = Arc; + +pub fn session_id() -> SessionId { + uuid::Uuid::new_v4().to_string().into() +} + +#[derive(Debug, Clone)] +pub struct SseServerConfig { + pub bind: SocketAddr, + pub sse_path: String, + pub post_path: String, + pub ct: CancellationToken, + pub sse_keep_alive: Option, +} + +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); \ No newline at end of file diff --git a/crates/rmcp/src/transport/sse_server/mod.rs b/crates/rmcp/src/transport/sse_server/mod.rs new file mode 100644 index 00000000..68c9e29e --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/mod.rs @@ -0,0 +1,27 @@ +#[cfg(feature = "transport-sse-server")] +pub mod common; + +// When only axum is enabled +#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] +mod axum_impl; + +#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] +pub use axum_impl::*; + +// When actix-web is enabled (with or without axum) +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +mod actix_impl; + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +pub use actix_impl::*; + +// When both are enabled, also provide axum implementation under different name +#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "actix-web"))] +pub mod axum_impl; + +#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "actix-web"))] +pub use axum_impl::SseServer as AxumSseServer; + +// Re-export common types when transport-sse-server is enabled +#[cfg(feature = "transport-sse-server")] +pub use common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; \ No newline at end of file diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index 4f4fccee..50936d02 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -3,7 +3,7 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + schemars, tool, tool_router, tool_handler, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -51,6 +51,7 @@ impl Calculator { } } +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/crates/rmcp/tests/test_sse_server.rs b/crates/rmcp/tests/test_sse_server.rs new file mode 100644 index 00000000..18893ba6 --- /dev/null +++ b/crates/rmcp/tests/test_sse_server.rs @@ -0,0 +1,250 @@ +#![cfg(feature = "transport-sse-server")] + +use rmcp::{ + ServiceExt, + transport::{SseServer, sse_server::SseServerConfig}, +}; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +mod common; +use common::calculator::Calculator; + +async fn init() { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); +} + +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +#[tokio::test] +async fn test_axum_sse_server_basic() -> anyhow::Result<()> { + init().await; + + let config = SseServerConfig { + bind: "127.0.0.1:0".parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + #[cfg(not(feature = "actix-web"))] + let sse_server = SseServer::serve_with_config(config).await?; + #[cfg(feature = "actix-web")] + let sse_server = rmcp::transport::sse_server::AxumSseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + + let service_ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test that server is running by making a request + let client = reqwest::Client::new(); + let response = client + .get(format!("http://{}/sse", bind_addr)) + .header("Accept", "text/event-stream") + .send() + .await?; + + // SSE endpoint should return OK and start streaming + assert_eq!(response.status(), reqwest::StatusCode::OK); + + ct.cancel(); + service_ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +#[actix_web::test] +async fn test_actix_sse_server_basic() -> anyhow::Result<()> { + init().await; + + let config = SseServerConfig { + bind: "127.0.0.1:0".parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let sse_server = SseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + + let service_ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // Test that server is running by making a request + let client = reqwest::Client::new(); + let response = client + .get(format!("http://{}/sse", bind_addr)) + .header("Accept", "text/event-stream") + .send() + .await?; + + // SSE endpoint should return OK and start streaming + assert_eq!(response.status(), reqwest::StatusCode::OK); + + ct.cancel(); + service_ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "transport-sse-client"))] +#[tokio::test] +async fn test_axum_client_server_integration() -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + #[cfg(not(feature = "actix-web"))] + let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + #[cfg(feature = "actix-web")] + let sse_server = rmcp::transport::sse_server::AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; + let client = ().serve(transport).await?; + + // Test basic operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + client.cancel().await?; + ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web", feature = "transport-sse-client"))] +#[actix_web::test] +async fn test_actix_client_server_integration() -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; + let client = ().serve(transport).await?; + + // Test basic operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + client.cancel().await?; + ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +#[tokio::test] +async fn test_axum_concurrent_clients() -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + const NUM_CLIENTS: usize = 5; + + #[cfg(not(feature = "actix-web"))] + let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + #[cfg(feature = "actix-web")] + let sse_server = rmcp::transport::sse_server::AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + let mut handles = vec![]; + + for i in 0..NUM_CLIENTS { + let addr = actual_addr; + let handle = tokio::spawn(async move { + let transport = SseClientTransport::start(format!("http://{}/sse", addr)).await?; + let client = ().serve(transport).await?; + + // Each client does some operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + tracing::info!("Client {} completed operations", i); + client.cancel().await?; + Ok::<(), anyhow::Error>(()) + }); + handles.push(handle); + } + + // Wait for all clients to complete + for handle in handles { + handle.await??; + } + + ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +#[actix_web::test] +async fn test_actix_concurrent_clients() -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + const NUM_CLIENTS: usize = 5; + + let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + let mut handles = vec![]; + + for i in 0..NUM_CLIENTS { + let addr = actual_addr; + let handle = tokio::spawn(async move { + let transport = SseClientTransport::start(format!("http://{}/sse", addr)).await?; + let client = ().serve(transport).await?; + + // Each client does some operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + tracing::info!("Client {} completed operations", i); + client.cancel().await?; + Ok::<(), anyhow::Error>(()) + }); + handles.push(handle); + } + + // Wait for all clients to complete + for handle in handles { + handle.await??; + } + + ct.cancel(); + Ok(()) +} \ No newline at end of file diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index a6acd646..5e2c4db6 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -33,6 +33,8 @@ tracing-subscriber = { version = "0.3", features = [ futures = "0.3" rand = { version = "0.9", features = ["std"] } axum = { version = "0.8", features = ["macros"] } +actix-web = { version = "4", optional = true } +actix-rt = { version = "2", optional = true } schemars = { version = "0.8", optional = true } reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" @@ -82,3 +84,11 @@ path = "src/counter_hyper_streamable_http.rs" [[example]] name = "servers_sampling_stdio" path = "src/sampling_stdio.rs" + +[[example]] +name = "servers_counter_sse_actix" +path = "src/counter_sse_actix.rs" +required-features = ["actix-web"] + +[features] +actix-web = ["dep:actix-web", "dep:actix-rt", "rmcp/actix-web"] diff --git a/examples/servers/src/counter_sse_actix.rs b/examples/servers/src/counter_sse_actix.rs new file mode 100644 index 00000000..532e9cb1 --- /dev/null +++ b/examples/servers/src/counter_sse_actix.rs @@ -0,0 +1,52 @@ +use rmcp::transport::sse_server::{SseServer, SseServerConfig}; +use tracing_subscriber::{ + layer::SubscriberExt, + util::SubscriberInitExt, + {self}, +}; +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let config = SseServerConfig { + bind: BIND_ADDRESS.parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: tokio_util::sync::CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct_signal = config.ct.clone(); + + let sse_server = SseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let ct = sse_server.with_service(Counter::new); + + println!("\nšŸš€ SSE Server (actix-web) running at http://{}", bind_addr); + println!("šŸ“” SSE endpoint: http://{}/sse", bind_addr); + println!("šŸ“® Message endpoint: http://{}/message", bind_addr); + println!("\nPress Ctrl+C to stop the server\n"); + + // Set up Ctrl-C handler + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + println!("\nā¹ļø Shutting down..."); + ct_signal.cancel(); + }); + + // Wait for cancellation + ct.cancelled().await; + println!("āœ… Server stopped"); + Ok(()) +} \ No newline at end of file From 1389f34cb310e835204021ace871d15f77ceb815 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 14:58:40 +0200 Subject: [PATCH 02/12] fix: add missing transport-worker dependency to streamable-http-server The LocalSessionManager in streamable-http-server uses WorkerTransport, which requires the transport-worker feature. This was working accidentally in examples because transport-sse-server also includes transport-worker. Without this fix, using transport-streamable-http-server independently (without transport-sse-server) would fail to compile with unresolved import errors for WorkerTransport. --- crates/rmcp/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index e7dd3365..6420d107 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -130,6 +130,7 @@ transport-streamable-http-server = [ transport-streamable-http-server-session = [ "transport-async-rw", "dep:tokio-stream", + "transport-worker", ] # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] From 7c1a237c85820eda59a7363963f8b8613521020a Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 17:07:54 +0200 Subject: [PATCH 03/12] feat: add actix-web support for streamable HTTP server transport - Implement ActixStreamableHttpService with full feature parity to Axum implementation - Support GET (SSE streams), POST (requests), DELETE (session cleanup) endpoints - Add comprehensive logging matching SSE implementation - Use X-Accel-Buffering header constant for consistency - Create working example server (counter_streamable_http_actix.rs) - Refactor tower.rs to import StreamableHttpServerConfig from parent module - Add HEADER_X_ACCEL_BUFFERING constant to common http_header module - Update streamable_client.js to accept URL as command line argument - Configure example to bind to 127.0.0.1 for IPv4 compatibility The actix-web implementation provides identical functionality to the Axum version while following actix-web patterns and conventions. --- .../rmcp/src/transport/common/http_header.rs | 1 + .../src/transport/streamable_http_server.rs | 78 ++- .../streamable_http_server/actix_impl.rs | 479 ++++++++++++++++++ .../transport/streamable_http_server/tower.rs | 21 +- crates/rmcp/tests/test_with_js.rs | 74 ++- .../tests/test_with_js/streamable_client.js | 3 +- examples/servers/Cargo.toml | 4 + .../src/counter_streamable_http_actix.rs | 41 ++ 8 files changed, 671 insertions(+), 30 deletions(-) create mode 100644 crates/rmcp/src/transport/streamable_http_server/actix_impl.rs create mode 100644 examples/servers/src/counter_streamable_http_actix.rs diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 84bc7bfb..275c9f18 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -1,4 +1,5 @@ pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +pub const HEADER_X_ACCEL_BUFFERING: &str = "X-Accel-Buffering"; pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; pub const JSON_MIME_TYPE: &str = "application/json"; diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index 733fc5e5..12cfd173 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -1,8 +1,76 @@ +//! Streamable HTTP Server Transport Module +//! +//! This module provides streamable HTTP transport implementations for MCP. +//! +//! # Type Export Strategy +//! +//! This module exports framework-specific implementations with explicit names: +//! - `AxumStreamableHttpService` - The Axum-based streamable HTTP service implementation +//! - `ActixStreamableHttpService` - The actix-web-based streamable HTTP service implementation +//! +//! For convenience, a type alias `StreamableHttpService` is provided that resolves to: +//! - `ActixStreamableHttpService` when the `actix-web` feature is enabled +//! - `AxumStreamableHttpService` when only the `axum` feature is enabled +//! +//! # Examples +//! +//! Using the convenience alias (recommended for most use cases): +//! ```ignore +//! use rmcp::transport::StreamableHttpService; +//! let service = StreamableHttpService::new(|| Ok(handler), session_manager, config); +//! ``` +//! +//! Using explicit types (when you need a specific implementation): +//! ```ignore +//! #[cfg(feature = "axum")] +//! use rmcp::transport::AxumStreamableHttpService; +//! #[cfg(feature = "axum")] +//! let service = AxumStreamableHttpService::new(|| Ok(handler), session_manager, config); +//! ``` + pub mod session; -#[cfg(feature = "transport-streamable-http-server")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] + +use std::time::Duration; + +/// Configuration for the streamable HTTP server +#[derive(Debug, Clone)] +pub struct StreamableHttpServerConfig { + /// The ping message duration for SSE connections. + pub sse_keep_alive: Option, + /// If true, the server will create a session for each request and keep it alive. + pub stateful_mode: bool, +} + +impl Default for StreamableHttpServerConfig { + fn default() -> Self { + Self { + sse_keep_alive: Some(Duration::from_secs(15)), + stateful_mode: true, + } + } +} + +// Axum implementation +#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))))] pub mod tower; + +#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] +pub use tower::StreamableHttpService as AxumStreamableHttpService; + +// Actix-web implementation +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))))] +pub mod actix_impl; + +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +pub use actix_impl::StreamableHttpService as ActixStreamableHttpService; + +// Export the preferred implementation as StreamableHttpService (without generic parameters) +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +pub use actix_impl::StreamableHttpService; + +#[cfg(all(feature = "transport-streamable-http-server", feature = "axum", not(feature = "actix-web")))] +pub use tower::StreamableHttpService; + pub use session::{SessionId, SessionManager}; -#[cfg(feature = "transport-streamable-http-server")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; diff --git a/crates/rmcp/src/transport/streamable_http_server/actix_impl.rs b/crates/rmcp/src/transport/streamable_http_server/actix_impl.rs new file mode 100644 index 00000000..61ad1b25 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/actix_impl.rs @@ -0,0 +1,479 @@ +use std::sync::Arc; + +use actix_web::{ + HttpRequest, HttpResponse, Result, error::InternalError, http::{StatusCode, header}, middleware, web::{self, Bytes, Data}, +}; +use futures::{Stream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; + +use super::{StreamableHttpServerConfig, session::SessionManager}; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ClientRequest}, + serve_server, + service::serve_directly, + transport::{ + OneshotTransport, TransportAdapterIdentity, + common::{ + http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, HEADER_X_ACCEL_BUFFERING, JSON_MIME_TYPE, + }, + }, + }, +}; + +#[derive(Clone)] +pub struct StreamableHttpService { + pub config: StreamableHttpServerConfig, + session_manager: Arc, + service_factory: Arc Result + Send + Sync>, +} + +impl StreamableHttpService +where + S: crate::Service + Send + 'static, + M: SessionManager + 'static, +{ + pub fn new( + service_factory: impl Fn() -> Result + Send + Sync + 'static, + session_manager: Arc, + config: StreamableHttpServerConfig, + ) -> Self { + Self { + config, + session_manager, + service_factory: Arc::new(service_factory), + } + } + + fn get_service(&self) -> Result { + (self.service_factory)() + } + + /// Configure actix_web routes for the streamable HTTP server + pub fn configure(service: Arc) -> impl FnOnce(&mut web::ServiceConfig) { + move |cfg: &mut web::ServiceConfig| { + cfg.service( + web::scope("/") + .app_data(Data::new(service.clone())) + .wrap(middleware::NormalizePath::trim()) + .route("", web::get().to(Self::handle_get)) + .route("", web::post().to(Self::handle_post)) + .route("", web::delete().to(Self::handle_delete)) + ); + } + } + + async fn handle_get( + req: HttpRequest, + service: Data>>, + ) -> Result { + // Check accept header + let accept = req + .headers() + .get(header::ACCEPT) + .and_then(|h| h.to_str().ok()); + + if !accept.is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) { + return Ok(HttpResponse::NotAcceptable() + .body("Not Acceptable: Client must accept text/event-stream")); + } + + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + + let Some(session_id) = session_id else { + return Ok(HttpResponse::Unauthorized() + .body("Unauthorized: Session ID is required")); + }; + + tracing::debug!(%session_id, "GET request for SSE stream"); + + // Check if session exists + let has_session = service + .session_manager + .has_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + if !has_session { + return Ok(HttpResponse::Unauthorized() + .body("Unauthorized: Session not found")); + } + + // Check if last event id is provided + let last_event_id = req + .headers() + .get(HEADER_LAST_EVENT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned()); + + // Get the appropriate stream + let sse_stream: std::pin::Pin + Send>> = if let Some(last_event_id) = last_event_id { + tracing::debug!(%session_id, %last_event_id, "Resuming stream from last event"); + Box::pin(service + .session_manager + .resume(&session_id, last_event_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?) + } else { + tracing::debug!(%session_id, "Creating standalone stream"); + Box::pin(service + .session_manager + .create_standalone_stream(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?) + }; + + // Convert to SSE format + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = sse_stream; + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! { + Some(msg) = stream.next() => { + let data = serde_json::to_string(&msg.message) + .unwrap_or_else(|_| "{}".to_string()); + let mut output = String::new(); + if let Some(id) = msg.event_id { + output.push_str(&format!("id: {}\n", id)); + } + output.push_str(&format!("data: {}\n\n", data)); + yield Ok::<_, actix_web::Error>(Bytes::from(output)); + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type("text/event-stream") + .append_header(("Cache-Control", "no-cache")) + .append_header(("X-Accel-Buffering", "no")) + .streaming(sse_stream)) + } + + async fn handle_post( + req: HttpRequest, + body: Bytes, + service: Data>>, + ) -> Result { + // Check accept header + let accept = req + .headers() + .get(header::ACCEPT) + .and_then(|h| h.to_str().ok()); + + if !accept.is_some_and(|header| { + header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE) + }) { + return Ok(HttpResponse::NotAcceptable() + .body("Not Acceptable: Client must accept both application/json and text/event-stream")); + } + + // Check content type + let content_type = req + .headers() + .get(header::CONTENT_TYPE) + .and_then(|h| h.to_str().ok()); + + if !content_type.is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) { + return Ok(HttpResponse::UnsupportedMediaType() + .body("Unsupported Media Type: Content-Type must be application/json")); + } + + // Deserialize the message + let mut message: ClientJsonRpcMessage = serde_json::from_slice(&body) + .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))?; + + tracing::debug!(?message, "POST request with message"); + + if service.config.stateful_mode { + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()); + + if let Some(session_id) = session_id { + let session_id = session_id.to_owned().into(); + tracing::debug!(%session_id, "POST request with existing session"); + + let has_session = service + .session_manager + .has_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + if !has_session { + tracing::warn!(%session_id, "Session not found"); + return Ok(HttpResponse::Unauthorized() + .body("Unauthorized: Session not found")); + } + + // Note: In actix-web we can't inject request parts like in tower, + // but session_id is already available through headers + + match message { + ClientJsonRpcMessage::Request(_) => { + let stream = service + .session_manager + .create_stream(&session_id, message) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + // Convert to SSE format + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = Box::pin(stream); + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! { + Some(msg) = stream.next() => { + let data = serde_json::to_string(&msg.message) + .unwrap_or_else(|_| "{}".to_string()); + let mut output = String::new(); + if let Some(id) = msg.event_id { + output.push_str(&format!("id: {}\n", id)); + } + output.push_str(&format!("data: {}\n\n", data)); + yield Ok::<_, actix_web::Error>(Bytes::from(output)); + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .streaming(sse_stream)) + } + ClientJsonRpcMessage::Notification(_) + | ClientJsonRpcMessage::Response(_) + | ClientJsonRpcMessage::Error(_) => { + // Handle notification + service + .session_manager + .accept_message(&session_id, message) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + Ok(HttpResponse::Accepted().finish()) + } + ClientJsonRpcMessage::BatchRequest(_) | ClientJsonRpcMessage::BatchResponse(_) => { + Ok(HttpResponse::NotImplemented() + .body("Batch requests are not supported yet")) + } + } + } else { + // No session id in stateful mode - create new session + tracing::debug!("POST request without session, creating new session"); + + let (session_id, transport) = service + .session_manager + .create_session() + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + tracing::info!(%session_id, "Created new session"); + + if let ClientJsonRpcMessage::Request(req) = &mut message { + if !matches!(req.request, ClientRequest::InitializeRequest(_)) { + return Ok(HttpResponse::UnprocessableEntity() + .body("Expected initialize request")); + } + } else { + return Ok(HttpResponse::UnprocessableEntity() + .body("Expected initialize request")); + } + + let service_instance = service + .get_service() + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + // Spawn a task to serve the session + tokio::spawn({ + let session_manager = service.session_manager.clone(); + let session_id = session_id.clone(); + async move { + let service = serve_server::( + service_instance, transport, + ) + .await; + match service { + Ok(service) => { + let _ = service.waiting().await; + } + Err(e) => { + tracing::error!("Failed to create service: {e}"); + } + } + let _ = session_manager + .close_session(&session_id) + .await + .inspect_err(|e| { + tracing::error!("Failed to close session {session_id}: {e}"); + }); + } + }); + + // Get initialize response + let response = service + .session_manager + .initialize_session(&session_id, message) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + // Return SSE stream with single response + let sse_stream = async_stream::stream! { + yield Ok::<_, actix_web::Error>(Bytes::from(format!( + "data: {}\n\n", + serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string()) + ))); + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .append_header((HEADER_SESSION_ID, session_id.as_ref())) + .streaming(sse_stream)) + } + } else { + // Stateless mode + tracing::debug!("POST request in stateless mode"); + + match message { + ClientJsonRpcMessage::Request(request) => { + tracing::debug!(?request, "Processing request in stateless mode"); + + // In stateless mode, handle the request directly + let service_instance = service + .get_service() + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + let (transport, receiver) = + OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); + let service_handle = serve_directly(service_instance, transport, None); + + tokio::spawn(async move { + // Let the service process the request + let _ = service_handle.waiting().await; + }); + + // Convert receiver stream to SSE format + let sse_stream = ReceiverStream::new(receiver).map(|message| { + tracing::info!(?message); + let data = serde_json::to_string(&message) + .unwrap_or_else(|_| "{}".to_string()); + Ok::<_, actix_web::Error>(Bytes::from(format!("data: {}\n\n", data))) + }); + + // Add keep-alive if configured + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = Box::pin(sse_stream); + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! { + Some(result) = stream.next() => { + match result { + Ok(data) => yield Ok(data), + Err(e) => yield Err(e), + } + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .streaming(sse_stream)) + } + _ => { + Ok(HttpResponse::UnprocessableEntity() + .body("Unexpected message type")) + } + } + } + } + + async fn handle_delete( + req: HttpRequest, + service: Data>>, + ) -> Result { + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + + let Some(session_id) = session_id else { + return Ok(HttpResponse::Unauthorized() + .body("Unauthorized: Session ID is required")); + }; + + tracing::debug!(%session_id, "DELETE request to close session"); + + // Close session + service + .session_manager + .close_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + tracing::info!(%session_id, "Session closed"); + + Ok(HttpResponse::NoContent().finish()) + } +} \ No newline at end of file diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 0ed0858e..87c9c083 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; +use std::{convert::Infallible, fmt::Display, sync::Arc}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; @@ -7,7 +7,7 @@ use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; -use super::session::SessionManager; +use super::{StreamableHttpServerConfig, session::SessionManager}; use crate::{ RoleServer, model::{ClientJsonRpcMessage, ClientRequest, GetExtensions}, @@ -27,23 +27,6 @@ use crate::{ }, }; -#[derive(Debug, Clone)] -pub struct StreamableHttpServerConfig { - /// The ping message duration for SSE connections. - pub sse_keep_alive: Option, - /// If true, the server will create a session for each request and keep it alive. - pub stateful_mode: bool, -} - -impl Default for StreamableHttpServerConfig { - fn default() -> Self { - Self { - sse_keep_alive: Some(Duration::from_secs(15)), - stateful_mode: true, - } - } -} - pub struct StreamableHttpService { pub config: StreamableHttpServerConfig, session_manager: Arc, diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 3f2761cd..c9faa4fd 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -5,10 +5,16 @@ use rmcp::{ ConfigureCommandExt, SseServer, StreamableHttpClientTransport, StreamableHttpServerConfig, TokioChildProcess, streamable_http_server::{ - session::local::LocalSessionManager, tower::StreamableHttpService, + session::local::LocalSessionManager, }, }, }; + +// Import framework-specific types +#[cfg(feature = "axum")] +use rmcp::transport::AxumStreamableHttpService; +#[cfg(feature = "actix-web")] +use rmcp::transport::ActixStreamableHttpService; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; @@ -16,7 +22,8 @@ use common::calculator::Calculator; const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; -const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8002"; +const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8002"; +const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8003"; #[tokio::test] async fn test_with_js_client() -> anyhow::Result<()> { @@ -78,8 +85,9 @@ async fn test_with_js_server() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "axum")] #[tokio::test] -async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { +async fn test_with_js_streamable_http_client_axum() -> anyhow::Result<()> { let _ = tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -94,8 +102,8 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .wait() .await?; - let service: StreamableHttpService = - StreamableHttpService::new( + let service: AxumStreamableHttpService = + AxumStreamableHttpService::new( || Ok(Calculator::new()), Default::default(), StreamableHttpServerConfig { @@ -125,6 +133,62 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "actix-web")] +#[actix_web::test] +async fn test_with_js_streamable_http_client_actix() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let service = std::sync::Arc::new(ActixStreamableHttpService::::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + }, + )); + + let server = actix_web::HttpServer::new(move || { + actix_web::App::new() + .wrap(actix_web::middleware::Logger::default()) + .service( + actix_web::web::scope("/mcp") + .configure(ActixStreamableHttpService::configure(service.clone())) + ) + }) + .bind(STREAMABLE_HTTP_ACTIX_BIND_ADDRESS)? + .run(); + + let server_handle = server.handle(); + let server_task = tokio::spawn(server); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + let exit_status = tokio::process::Command::new("node") + .arg("tests/test_with_js/streamable_client.js") + .arg(format!("http://{}/mcp/", STREAMABLE_HTTP_ACTIX_BIND_ADDRESS)) + .spawn()? + .wait() + .await?; + assert!(exit_status.success()); + + server_handle.stop(true).await; + let _ = server_task.await; + Ok(()) +} + #[tokio::test] async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { let _ = tracing_subscriber::registry() diff --git a/crates/rmcp/tests/test_with_js/streamable_client.js b/crates/rmcp/tests/test_with_js/streamable_client.js index 99826131..fb86b7ac 100644 --- a/crates/rmcp/tests/test_with_js/streamable_client.js +++ b/crates/rmcp/tests/test_with_js/streamable_client.js @@ -1,7 +1,8 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/mcp/`)); +const url = process.argv[2] || "http://127.0.0.1:8001/mcp/"; +const transport = new StreamableHTTPClientTransport(new URL(url)); const client = new Client( { diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 5e2c4db6..a4941814 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -88,6 +88,10 @@ path = "src/sampling_stdio.rs" [[example]] name = "servers_counter_sse_actix" path = "src/counter_sse_actix.rs" + +[[example]] +name = "servers_counter_streamable_http_actix" +path = "src/counter_streamable_http_actix.rs" required-features = ["actix-web"] [features] diff --git a/examples/servers/src/counter_streamable_http_actix.rs b/examples/servers/src/counter_streamable_http_actix.rs new file mode 100644 index 00000000..afbfe94f --- /dev/null +++ b/examples/servers/src/counter_streamable_http_actix.rs @@ -0,0 +1,41 @@ +mod common; +use std::sync::Arc; + +use actix_web::{App, HttpServer, middleware}; +use common::counter::Counter; +use rmcp::transport::streamable_http_server::{ + StreamableHttpService, session::local::LocalSessionManager, +}; + +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + let bind_addr = "127.0.0.1:8080"; + + // Create the streamable HTTP service + let service = Arc::new(StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + Default::default(), + )); + + println!("Starting actix-web streamable HTTP server on {}", bind_addr); + println!("POST / - Send JSON-RPC requests"); + println!("GET / - Resume SSE stream with session ID"); + println!("DELETE / - Close session"); + + HttpServer::new(move || { + App::new() + .wrap(middleware::Logger::default()) + .configure(StreamableHttpService::configure(service.clone())) + }) + .bind(bind_addr)? + .run() + .await?; + + Ok(()) +} \ No newline at end of file From 380ec9fc5053ec75668113d3d613a076bd895546 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 18:05:55 +0200 Subject: [PATCH 04/12] refactor: implement explicit type names for multi-framework support - Export framework-specific types with explicit names (AxumSseServer, ActixSseServer, etc.) - Add comprehensive module documentation explaining the type export strategy - Update all tests to use explicit type names with proper feature gates - Remove unused TransportReceiver type aliases to fix warnings - Ensure feature gates are ordered with framework feature last This design provides both explicit control when both frameworks are enabled and maintains backward compatibility with existing code. --- crates/rmcp/src/transport.rs | 61 +++++- .../src/transport/sse_server/actix_impl.rs | 4 +- .../src/transport/sse_server/axum_impl.rs | 1 - crates/rmcp/src/transport/sse_server/mod.rs | 52 ++++- crates/rmcp/tests/test_sse_server.rs | 188 +++++++----------- crates/rmcp/tests/test_with_python.rs | 57 ++++-- 6 files changed, 217 insertions(+), 146 deletions(-) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index ab02abec..17684c25 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -10,7 +10,37 @@ //! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] | //! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] | //! -//!## Helper Transport Types +//! ## Framework Support +//! +//! Several transport types support multiple web frameworks through feature flags: +//! +//! ### SSE Server Transport +//! - **Convenience alias**: [`SseServer`] - resolves to the appropriate implementation based on enabled features +//! - **Explicit types**: [`AxumSseServer`], [`ActixSseServer`] - use specific framework implementations +//! +//! ### Streamable HTTP Server Transport +//! - **Convenience alias**: [`StreamableHttpService`] - resolves to the appropriate implementation based on enabled features +//! - **Explicit types**: [`AxumStreamableHttpService`], [`ActixStreamableHttpService`] - use specific framework implementations +//! +//! #### Type Resolution Strategy +//! The convenience aliases resolve as follows: +//! - When `actix-web` feature is enabled: aliases point to actix-web implementations +//! - When only `axum` feature is enabled: aliases point to axum implementations +//! +//! #### Usage Examples +//! ```rust,ignore +//! // Using convenience aliases (recommended for most cases) +//! use rmcp::transport::{SseServer, StreamableHttpService}; +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! +//! // Using explicit types (when you need a specific implementation) +//! #[cfg(feature = "axum")] +//! use rmcp::transport::AxumSseServer; +//! #[cfg(feature = "axum")] +//! let server = AxumSseServer::serve("127.0.0.1:8080".parse()?).await?; +//! ``` +//! +//! ## Helper Transport Types //! Thers are several helper transport types that can help you to create transport quickly. //! //! ### [Worker Transport](`worker::WorkerTransport`) @@ -105,10 +135,21 @@ pub use sse_client::SseClientTransport; #[cfg(feature = "transport-sse-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] pub mod sse_server; + +// Re-export convenience alias #[cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))))] pub use sse_server::SseServer; +// Re-export explicit types +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", feature = "axum"))))] +pub use sse_server::AxumSseServer; + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", feature = "actix-web"))))] +pub use sse_server::ActixSseServer; + #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] pub mod auth; @@ -122,9 +163,25 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized #[cfg(feature = "transport-streamable-http-server-session")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server-session")))] pub mod streamable_http_server; + +// Re-export configuration #[cfg(feature = "transport-streamable-http-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService}; +pub use streamable_http_server::StreamableHttpServerConfig; + +// Re-export the preferred implementation +#[cfg(all(feature = "transport-streamable-http-server", any(feature = "axum", feature = "actix-web")))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", any(feature = "axum", feature = "actix-web")))))] +pub use streamable_http_server::StreamableHttpService; + +// Re-export explicit types +#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))))] +pub use streamable_http_server::AxumStreamableHttpService; + +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))))] +pub use streamable_http_server::ActixStreamableHttpService; #[cfg(feature = "transport-streamable-http-client")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] diff --git a/crates/rmcp/src/transport/sse_server/actix_impl.rs b/crates/rmcp/src/transport/sse_server/actix_impl.rs index e3fc2722..47037352 100644 --- a/crates/rmcp/src/transport/sse_server/actix_impl.rs +++ b/crates/rmcp/src/transport/sse_server/actix_impl.rs @@ -18,10 +18,10 @@ use crate::{ }; use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; +use crate::transport::common::http_header::HEADER_X_ACCEL_BUFFERING; type TxStore = Arc>>>; -pub type TransportReceiver = ReceiverStream>; #[derive(Clone, Debug)] struct AppData { @@ -169,7 +169,7 @@ async fn sse_handler( Ok(HttpResponse::Ok() .content_type("text/event-stream") .insert_header(("Cache-Control", "no-cache")) - .insert_header(("X-Accel-Buffering", "no")) + .insert_header((HEADER_X_ACCEL_BUFFERING, "no")) .streaming(sse_stream)) } diff --git a/crates/rmcp/src/transport/sse_server/axum_impl.rs b/crates/rmcp/src/transport/sse_server/axum_impl.rs index d3c9e9e3..974e96d7 100644 --- a/crates/rmcp/src/transport/sse_server/axum_impl.rs +++ b/crates/rmcp/src/transport/sse_server/axum_impl.rs @@ -25,7 +25,6 @@ use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_IN type TxStore = Arc>>>; -pub type TransportReceiver = ReceiverStream>; #[derive(Clone)] struct App { diff --git a/crates/rmcp/src/transport/sse_server/mod.rs b/crates/rmcp/src/transport/sse_server/mod.rs index 68c9e29e..206ac3bc 100644 --- a/crates/rmcp/src/transport/sse_server/mod.rs +++ b/crates/rmcp/src/transport/sse_server/mod.rs @@ -1,26 +1,56 @@ +//! SSE Server Transport Module +//! +//! This module provides Server-Sent Events (SSE) transport implementations for MCP. +//! +//! # Type Export Strategy +//! +//! This module exports framework-specific implementations with explicit names: +//! - `AxumSseServer` - The Axum-based SSE server implementation +//! - `ActixSseServer` - The actix-web-based SSE server implementation +//! +//! For convenience, a type alias `SseServer` is provided that resolves to: +//! - `ActixSseServer` when the `actix-web` feature is enabled +//! - `AxumSseServer` when only the `axum` feature is enabled +//! +//! # Examples +//! +//! Using the convenience alias (recommended for most use cases): +//! ```ignore +//! use rmcp::transport::SseServer; +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! ``` +//! +//! Using explicit types (when you need a specific implementation): +//! ```ignore +//! #[cfg(feature = "axum")] +//! use rmcp::transport::AxumSseServer; +//! #[cfg(feature = "axum")] +//! let server = AxumSseServer::serve("127.0.0.1:8080".parse()?).await?; +//! ``` + #[cfg(feature = "transport-sse-server")] pub mod common; -// When only axum is enabled -#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] +// Axum implementation +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] mod axum_impl; -#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] -pub use axum_impl::*; +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +pub use axum_impl::SseServer as AxumSseServer; -// When actix-web is enabled (with or without axum) +// Actix-web implementation #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] mod actix_impl; #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] -pub use actix_impl::*; +pub use actix_impl::SseServer as ActixSseServer; -// When both are enabled, also provide axum implementation under different name -#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "actix-web"))] -pub mod axum_impl; +// Convenience type alias +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +pub type SseServer = ActixSseServer; -#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "actix-web"))] -pub use axum_impl::SseServer as AxumSseServer; +#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] +pub type SseServer = AxumSseServer; // Re-export common types when transport-sse-server is enabled #[cfg(feature = "transport-sse-server")] diff --git a/crates/rmcp/tests/test_sse_server.rs b/crates/rmcp/tests/test_sse_server.rs index 18893ba6..57bb20a1 100644 --- a/crates/rmcp/tests/test_sse_server.rs +++ b/crates/rmcp/tests/test_sse_server.rs @@ -2,8 +2,15 @@ use rmcp::{ ServiceExt, - transport::{SseServer, sse_server::SseServerConfig}, + transport::sse_server::SseServerConfig, }; + +// Import framework-specific types +#[cfg(feature = "axum")] +use rmcp::transport::AxumSseServer; +#[cfg(feature = "actix-web")] +use rmcp::transport::ActixSseServer; + use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -20,28 +27,8 @@ async fn init() { .try_init(); } -#[cfg(all(feature = "transport-sse-server", feature = "axum"))] -#[tokio::test] -async fn test_axum_sse_server_basic() -> anyhow::Result<()> { - init().await; - - let config = SseServerConfig { - bind: "127.0.0.1:0".parse()?, - sse_path: "/sse".to_string(), - post_path: "/message".to_string(), - ct: CancellationToken::new(), - sse_keep_alive: None, - }; - - let ct = config.ct.clone(); - #[cfg(not(feature = "actix-web"))] - let sse_server = SseServer::serve_with_config(config).await?; - #[cfg(feature = "actix-web")] - let sse_server = rmcp::transport::sse_server::AxumSseServer::serve_with_config(config).await?; - let bind_addr = sse_server.config.bind; - - let service_ct = sse_server.with_service(Calculator::default); - +// Common test logic for basic SSE server test +async fn test_sse_server_basic_common(bind_addr: std::net::SocketAddr, ct: CancellationToken, service_ct: CancellationToken) -> anyhow::Result<()> { // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -61,6 +48,27 @@ async fn test_axum_sse_server_basic() -> anyhow::Result<()> { Ok(()) } +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +#[tokio::test] +async fn test_axum_sse_server_basic() -> anyhow::Result<()> { + init().await; + + let config = SseServerConfig { + bind: "127.0.0.1:0".parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let sse_server = AxumSseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let service_ct = sse_server.with_service(Calculator::default); + + test_sse_server_basic_common(bind_addr, ct, service_ct).await +} + #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] #[actix_web::test] async fn test_actix_sse_server_basic() -> anyhow::Result<()> { @@ -75,105 +83,68 @@ async fn test_actix_sse_server_basic() -> anyhow::Result<()> { }; let ct = config.ct.clone(); - let sse_server = SseServer::serve_with_config(config).await?; + let sse_server = ActixSseServer::serve_with_config(config).await?; let bind_addr = sse_server.config.bind; - let service_ct = sse_server.with_service(Calculator::default); - // Give the server a moment to start - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + test_sse_server_basic_common(bind_addr, ct, service_ct).await +} + +// Common client-server integration test logic +#[cfg(feature = "transport-sse-client")] +async fn test_client_server_integration_common(actual_addr: std::net::SocketAddr, ct: CancellationToken) -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; - // Test that server is running by making a request - let client = reqwest::Client::new(); - let response = client - .get(format!("http://{}/sse", bind_addr)) - .header("Accept", "text/event-stream") - .send() - .await?; + let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; + let client = ().serve(transport).await?; - // SSE endpoint should return OK and start streaming - assert_eq!(response.status(), reqwest::StatusCode::OK); + // Test basic operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + client.cancel().await?; ct.cancel(); - service_ct.cancel(); Ok(()) } -#[cfg(all(feature = "transport-sse-server", feature = "axum", feature = "transport-sse-client"))] +#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "axum"))] #[tokio::test] async fn test_axum_client_server_integration() -> anyhow::Result<()> { - use rmcp::transport::SseClientTransport; - init().await; const BIND_ADDRESS: &str = "127.0.0.1:0"; - #[cfg(not(feature = "actix-web"))] - let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; - #[cfg(feature = "actix-web")] - let sse_server = rmcp::transport::sse_server::AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); - let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; - let client = ().serve(transport).await?; - - // Test basic operations - let tools = client.list_all_tools().await?; - assert!(!tools.is_empty()); - assert_eq!(tools.len(), 2); // sum and sub - - client.cancel().await?; - ct.cancel(); - Ok(()) + test_client_server_integration_common(actual_addr, ct).await } -#[cfg(all(feature = "transport-sse-server", feature = "actix-web", feature = "transport-sse-client"))] +#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "actix-web"))] #[actix_web::test] async fn test_actix_client_server_integration() -> anyhow::Result<()> { - use rmcp::transport::SseClientTransport; - init().await; const BIND_ADDRESS: &str = "127.0.0.1:0"; - let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; - let client = ().serve(transport).await?; - - // Test basic operations - let tools = client.list_all_tools().await?; - assert!(!tools.is_empty()); - assert_eq!(tools.len(), 2); // sum and sub - - client.cancel().await?; - ct.cancel(); - Ok(()) + test_client_server_integration_common(actual_addr, ct).await } -#[cfg(all(feature = "transport-sse-server", feature = "axum"))] -#[tokio::test] -async fn test_axum_concurrent_clients() -> anyhow::Result<()> { +// Common concurrent clients test logic +#[cfg(feature = "transport-sse-client")] +async fn test_concurrent_clients_common(actual_addr: std::net::SocketAddr, ct: CancellationToken) -> anyhow::Result<()> { use rmcp::transport::SseClientTransport; - init().await; - - const BIND_ADDRESS: &str = "127.0.0.1:0"; const NUM_CLIENTS: usize = 5; - - #[cfg(not(feature = "actix-web"))] - let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; - #[cfg(feature = "actix-web")] - let sse_server = rmcp::transport::sse_server::AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; - let actual_addr = sse_server.config.bind; - let ct = sse_server.with_service(Calculator::default); - let mut handles = vec![]; for i in 0..NUM_CLIENTS { @@ -203,48 +174,33 @@ async fn test_axum_concurrent_clients() -> anyhow::Result<()> { Ok(()) } -#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "axum"))] +#[tokio::test] +async fn test_axum_concurrent_clients() -> anyhow::Result<()> { + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + test_concurrent_clients_common(actual_addr, ct).await +} + +#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "actix-web"))] #[actix_web::test] async fn test_actix_concurrent_clients() -> anyhow::Result<()> { - use rmcp::transport::SseClientTransport; - init().await; const BIND_ADDRESS: &str = "127.0.0.1:0"; - const NUM_CLIENTS: usize = 5; - let sse_server = SseServer::serve(BIND_ADDRESS.parse()?).await?; + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - let mut handles = vec![]; - - for i in 0..NUM_CLIENTS { - let addr = actual_addr; - let handle = tokio::spawn(async move { - let transport = SseClientTransport::start(format!("http://{}/sse", addr)).await?; - let client = ().serve(transport).await?; - - // Each client does some operations - let tools = client.list_all_tools().await?; - assert!(!tools.is_empty()); - assert_eq!(tools.len(), 2); // sum and sub - - tracing::info!("Client {} completed operations", i); - client.cancel().await?; - Ok::<(), anyhow::Error>(()) - }); - handles.push(handle); - } - - // Wait for all clients to complete - for handle in handles { - handle.await??; - } - - ct.cancel(); - Ok(()) + test_concurrent_clients_common(actual_addr, ct).await } \ No newline at end of file diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 9def1722..ca344e2f 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,8 +1,14 @@ -use axum::Router; use rmcp::{ ServiceExt, - transport::{ConfigureCommandExt, SseServer, TokioChildProcess, sse_server::SseServerConfig}, + transport::{ConfigureCommandExt, TokioChildProcess, sse_server::SseServerConfig}, }; + +// Import framework-specific types +#[cfg(feature = "axum")] +use rmcp::transport::AxumSseServer; +#[cfg(feature = "actix-web")] +use rmcp::transport::ActixSseServer; + use tokio::time::timeout; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -26,20 +32,12 @@ async fn init() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] -async fn test_with_python_client() -> anyhow::Result<()> { - init().await?; - - const BIND_ADDRESS: &str = "127.0.0.1:8000"; - - let ct = SseServer::serve(BIND_ADDRESS.parse()?) - .await? - .with_service(Calculator::default); - +// Common test logic for Python client +async fn test_with_python_client_common(bind_address: &str, ct: CancellationToken) -> anyhow::Result<()> { let status = tokio::process::Command::new("uv") .arg("run") .arg("client.py") - .arg(format!("http://{BIND_ADDRESS}/sse")) + .arg(format!("http://{bind_address}/sse")) .current_dir("tests/test_with_python") .spawn()? .wait() @@ -49,9 +47,40 @@ async fn test_with_python_client() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "axum")] +#[tokio::test] +async fn test_with_python_client_axum() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8000"; + + let ct = AxumSseServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + test_with_python_client_common(BIND_ADDRESS, ct).await +} + +#[cfg(feature = "actix-web")] +#[tokio::test] +async fn test_with_python_client_actix() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8000"; + + let ct = ActixSseServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + test_with_python_client_common(BIND_ADDRESS, ct).await +} + /// Test the SSE server in a nested Axum router. +#[cfg(feature = "axum")] #[tokio::test] async fn test_nested_with_python_client() -> anyhow::Result<()> { + use axum::Router; + init().await?; const BIND_ADDRESS: &str = "127.0.0.1:8001"; @@ -67,7 +96,7 @@ async fn test_nested_with_python_client() -> anyhow::Result<()> { let listener = tokio::net::TcpListener::bind(&sse_config.bind).await?; - let (sse_server, sse_router) = SseServer::new(sse_config); + let (sse_server, sse_router) = AxumSseServer::new(sse_config); let ct = sse_server.with_service(Calculator::default); let main_router = Router::new().nest("/nested", sse_router); From 82d64b7579d821267ecfa0cbba602af363dbc65a Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 18:51:15 +0200 Subject: [PATCH 05/12] fix: actix-web test attribute and add NormalizePath middleware - Use #[actix_web::test] instead of #[tokio::test] for actix-web tests - Add NormalizePath middleware to actix-web SSE server for consistent path handling --- crates/rmcp/src/transport/sse_server/actix_impl.rs | 2 ++ crates/rmcp/tests/test_with_python.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/rmcp/src/transport/sse_server/actix_impl.rs b/crates/rmcp/src/transport/sse_server/actix_impl.rs index 47037352..8ca4fe00 100644 --- a/crates/rmcp/src/transport/sse_server/actix_impl.rs +++ b/crates/rmcp/src/transport/sse_server/actix_impl.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, io, sync::Arc, time::Duration}; use actix_web::{ HttpRequest, HttpResponse, Result, Scope, error::ErrorInternalServerError, + middleware, web::{self, Bytes, Data, Json, Query}, }; use futures::{Sink, SinkExt, Stream, StreamExt}; @@ -278,6 +279,7 @@ impl SseServer { let server = actix_web::HttpServer::new(move || { actix_web::App::new() .app_data(app_data.clone()) + .wrap(middleware::NormalizePath::trim()) .route(&sse_path, web::get().to(sse_handler)) .route(&post_path, web::post().to(post_event_handler)) }) diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index ca344e2f..44c2df14 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -62,7 +62,7 @@ async fn test_with_python_client_axum() -> anyhow::Result<()> { } #[cfg(feature = "actix-web")] -#[tokio::test] +#[actix_web::test] async fn test_with_python_client_actix() -> anyhow::Result<()> { init().await?; From e247791cc4f07a4d76424bbf713b927dd916b385 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 18:56:27 +0200 Subject: [PATCH 06/12] fix: add feature gate for STREAMABLE_HTTP_ACTIX_BIND_ADDRESS constant Only define the constant when actix-web feature is enabled to avoid dead code warning --- crates/rmcp/tests/test_with_js.rs | 1 + examples/servers/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index c9faa4fd..9c17a5f8 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -22,6 +22,7 @@ use common::calculator::Calculator; const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; +#[cfg(feature = "actix-web")] const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8002"; const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8003"; diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index a4941814..8341b70a 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -88,6 +88,7 @@ path = "src/sampling_stdio.rs" [[example]] name = "servers_counter_sse_actix" path = "src/counter_sse_actix.rs" +required-features = ["actix-web"] [[example]] name = "servers_counter_streamable_http_actix" From a8c42224e7e36d020354c0362e157f3692dda737 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 19:11:00 +0200 Subject: [PATCH 07/12] fix: change actix-web test port to avoid conflict with JS server The JavaScript streamable server is hardcoded to use port 8002, which was conflicting with our actix-web test. Changed actix-web test to use port 8004. --- crates/rmcp/tests/test_with_js.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 9c17a5f8..01d246c3 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -23,7 +23,7 @@ use common::calculator::Calculator; const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; #[cfg(feature = "actix-web")] -const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8002"; +const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8004"; const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8003"; #[tokio::test] From 62455b97b2773cde78d7aaca605057fbc40714e5 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 19:18:48 +0200 Subject: [PATCH 08/12] fix: restore original JS server port to avoid breaking existing test Reverted STREAMABLE_HTTP_JS_BIND_ADDRESS back to port 8002 to match the hardcoded port in streamable_server.js. Our actix-web test uses port 8004 to avoid conflicts. --- crates/rmcp/tests/test_with_js.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 01d246c3..8387f2af 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -22,9 +22,9 @@ use common::calculator::Calculator; const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; +const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8002"; #[cfg(feature = "actix-web")] const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8004"; -const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8003"; #[tokio::test] async fn test_with_js_client() -> anyhow::Result<()> { From aa7ea2524e0bafd258238003ebdb3ed9b0e6c287 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 19:19:58 +0200 Subject: [PATCH 09/12] fix: use dep: prefix for rmcp/actix-web feature dependency The feature definition in examples/servers/Cargo.toml should use 'dep:rmcp/actix-web' to correctly enable the rmcp dependency feature. --- examples/servers/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 8341b70a..85bcee11 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -96,4 +96,4 @@ path = "src/counter_streamable_http_actix.rs" required-features = ["actix-web"] [features] -actix-web = ["dep:actix-web", "dep:actix-rt", "rmcp/actix-web"] +actix-web = ["dep:actix-web", "dep:actix-rt", "dep:rmcp/actix-web"] From bfa622b6849451db3177b785450265421dada6b1 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 19:38:25 +0200 Subject: [PATCH 10/12] refactor: use module-based organization for framework-specific transport types - Remove type prefixes (AxumSseServer, ActixSseServer, etc.) - Organize implementations in framework-specific submodules (axum, actix_web) - Each submodule exports types with consistent names (SseServer, StreamableHttpService) - Update all imports to use new module paths - Maintain backward compatibility with convenience aliases at module roots This provides a cleaner API with better module organization and makes it easier to add new framework implementations in the future. --- crates/rmcp/src/transport.rs | 32 ++++++---------- .../{actix_impl.rs => actix_web.rs} | 0 .../sse_server/{axum_impl.rs => axum.rs} | 0 crates/rmcp/src/transport/sse_server/mod.rs | 36 ++++++++---------- .../src/transport/streamable_http_server.rs | 38 +++++++++---------- .../{actix_impl.rs => actix_web.rs} | 0 .../{tower.rs => axum.rs} | 0 crates/rmcp/tests/test_sse_server.rs | 4 +- crates/rmcp/tests/test_with_js.rs | 4 +- crates/rmcp/tests/test_with_python.rs | 4 +- examples/servers/Cargo.toml | 2 +- 11 files changed, 51 insertions(+), 69 deletions(-) rename crates/rmcp/src/transport/sse_server/{actix_impl.rs => actix_web.rs} (100%) rename crates/rmcp/src/transport/sse_server/{axum_impl.rs => axum.rs} (100%) rename crates/rmcp/src/transport/streamable_http_server/{actix_impl.rs => actix_web.rs} (100%) rename crates/rmcp/src/transport/streamable_http_server/{tower.rs => axum.rs} (100%) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 17684c25..6e2f4bd4 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -16,11 +16,11 @@ //! //! ### SSE Server Transport //! - **Convenience alias**: [`SseServer`] - resolves to the appropriate implementation based on enabled features -//! - **Explicit types**: [`AxumSseServer`], [`ActixSseServer`] - use specific framework implementations +//! - **Framework-specific**: Available via `sse_server::axum::SseServer` and `sse_server::actix_web::SseServer` //! //! ### Streamable HTTP Server Transport //! - **Convenience alias**: [`StreamableHttpService`] - resolves to the appropriate implementation based on enabled features -//! - **Explicit types**: [`AxumStreamableHttpService`], [`ActixStreamableHttpService`] - use specific framework implementations +//! - **Framework-specific**: Available via `streamable_http_server::axum::StreamableHttpService` and `streamable_http_server::actix_web::StreamableHttpService` //! //! #### Type Resolution Strategy //! The convenience aliases resolve as follows: @@ -33,11 +33,11 @@ //! use rmcp::transport::{SseServer, StreamableHttpService}; //! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; //! -//! // Using explicit types (when you need a specific implementation) +//! // Using framework-specific modules (when you need a specific implementation) //! #[cfg(feature = "axum")] -//! use rmcp::transport::AxumSseServer; +//! use rmcp::transport::sse_server::axum::SseServer; //! #[cfg(feature = "axum")] -//! let server = AxumSseServer::serve("127.0.0.1:8080".parse()?).await?; +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; //! ``` //! //! ## Helper Transport Types @@ -141,14 +141,9 @@ pub mod sse_server; #[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))))] pub use sse_server::SseServer; -// Re-export explicit types -#[cfg(all(feature = "transport-sse-server", feature = "axum"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", feature = "axum"))))] -pub use sse_server::AxumSseServer; - -#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", feature = "actix-web"))))] -pub use sse_server::ActixSseServer; +// Framework-specific implementations are available via submodules: +// - sse_server::axum::SseServer +// - sse_server::actix_web::SseServer #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] @@ -174,14 +169,9 @@ pub use streamable_http_server::StreamableHttpServerConfig; #[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", any(feature = "axum", feature = "actix-web")))))] pub use streamable_http_server::StreamableHttpService; -// Re-export explicit types -#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))))] -pub use streamable_http_server::AxumStreamableHttpService; - -#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))))] -pub use streamable_http_server::ActixStreamableHttpService; +// Framework-specific implementations are available via submodules: +// - streamable_http_server::axum::StreamableHttpService +// - streamable_http_server::actix_web::StreamableHttpService #[cfg(feature = "transport-streamable-http-client")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] diff --git a/crates/rmcp/src/transport/sse_server/actix_impl.rs b/crates/rmcp/src/transport/sse_server/actix_web.rs similarity index 100% rename from crates/rmcp/src/transport/sse_server/actix_impl.rs rename to crates/rmcp/src/transport/sse_server/actix_web.rs diff --git a/crates/rmcp/src/transport/sse_server/axum_impl.rs b/crates/rmcp/src/transport/sse_server/axum.rs similarity index 100% rename from crates/rmcp/src/transport/sse_server/axum_impl.rs rename to crates/rmcp/src/transport/sse_server/axum.rs diff --git a/crates/rmcp/src/transport/sse_server/mod.rs b/crates/rmcp/src/transport/sse_server/mod.rs index 206ac3bc..3a2e07aa 100644 --- a/crates/rmcp/src/transport/sse_server/mod.rs +++ b/crates/rmcp/src/transport/sse_server/mod.rs @@ -2,15 +2,17 @@ //! //! This module provides Server-Sent Events (SSE) transport implementations for MCP. //! -//! # Type Export Strategy +//! # Module Organization //! -//! This module exports framework-specific implementations with explicit names: -//! - `AxumSseServer` - The Axum-based SSE server implementation -//! - `ActixSseServer` - The actix-web-based SSE server implementation +//! Framework-specific implementations are organized in submodules: +//! - `axum` - Contains the Axum-based SSE server implementation +//! - `actix_web` - Contains the actix-web-based SSE server implementation //! -//! For convenience, a type alias `SseServer` is provided that resolves to: -//! - `ActixSseServer` when the `actix-web` feature is enabled -//! - `AxumSseServer` when only the `axum` feature is enabled +//! Each submodule exports a `SseServer` type with the same interface. +//! +//! For convenience, a type alias `SseServer` is provided at the module root that resolves to: +//! - `actix_web::SseServer` when the `actix-web` feature is enabled +//! - `axum::SseServer` when only the `axum` feature is enabled //! //! # Examples //! @@ -20,12 +22,12 @@ //! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; //! ``` //! -//! Using explicit types (when you need a specific implementation): +//! Using framework-specific modules (when you need a specific implementation): //! ```ignore //! #[cfg(feature = "axum")] -//! use rmcp::transport::AxumSseServer; +//! use rmcp::transport::sse_server::axum::SseServer; //! #[cfg(feature = "axum")] -//! let server = AxumSseServer::serve("127.0.0.1:8080".parse()?).await?; +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; //! ``` #[cfg(feature = "transport-sse-server")] @@ -33,24 +35,18 @@ pub mod common; // Axum implementation #[cfg(all(feature = "transport-sse-server", feature = "axum"))] -mod axum_impl; - -#[cfg(all(feature = "transport-sse-server", feature = "axum"))] -pub use axum_impl::SseServer as AxumSseServer; +pub mod axum; // Actix-web implementation #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] -mod actix_impl; - -#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] -pub use actix_impl::SseServer as ActixSseServer; +pub mod actix_web; // Convenience type alias #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] -pub type SseServer = ActixSseServer; +pub use actix_web::SseServer; #[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] -pub type SseServer = AxumSseServer; +pub use axum::SseServer; // Re-export common types when transport-sse-server is enabled #[cfg(feature = "transport-sse-server")] diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index 12cfd173..c6046e7f 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -2,15 +2,17 @@ //! //! This module provides streamable HTTP transport implementations for MCP. //! -//! # Type Export Strategy +//! # Module Organization //! -//! This module exports framework-specific implementations with explicit names: -//! - `AxumStreamableHttpService` - The Axum-based streamable HTTP service implementation -//! - `ActixStreamableHttpService` - The actix-web-based streamable HTTP service implementation +//! Framework-specific implementations are organized in submodules: +//! - `axum` - Contains the Axum-based streamable HTTP service implementation +//! - `actix_web` - Contains the actix-web-based streamable HTTP service implementation //! -//! For convenience, a type alias `StreamableHttpService` is provided that resolves to: -//! - `ActixStreamableHttpService` when the `actix-web` feature is enabled -//! - `AxumStreamableHttpService` when only the `axum` feature is enabled +//! Each submodule exports a `StreamableHttpService` type with the same interface. +//! +//! For convenience, a type alias `StreamableHttpService` is provided at the module root that resolves to: +//! - `actix_web::StreamableHttpService` when the `actix-web` feature is enabled +//! - `axum::StreamableHttpService` when only the `axum` feature is enabled //! //! # Examples //! @@ -20,12 +22,12 @@ //! let service = StreamableHttpService::new(|| Ok(handler), session_manager, config); //! ``` //! -//! Using explicit types (when you need a specific implementation): +//! Using framework-specific modules (when you need a specific implementation): //! ```ignore //! #[cfg(feature = "axum")] -//! use rmcp::transport::AxumStreamableHttpService; +//! use rmcp::transport::streamable_http_server::axum::StreamableHttpService; //! #[cfg(feature = "axum")] -//! let service = AxumStreamableHttpService::new(|| Ok(handler), session_manager, config); +//! let service = StreamableHttpService::new(|| Ok(handler), session_manager, config); //! ``` pub mod session; @@ -53,24 +55,18 @@ impl Default for StreamableHttpServerConfig { // Axum implementation #[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))))] -pub mod tower; - -#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] -pub use tower::StreamableHttpService as AxumStreamableHttpService; +pub mod axum; // Actix-web implementation #[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))))] -pub mod actix_impl; - -#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] -pub use actix_impl::StreamableHttpService as ActixStreamableHttpService; +pub mod actix_web; // Export the preferred implementation as StreamableHttpService (without generic parameters) #[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] -pub use actix_impl::StreamableHttpService; +pub use actix_web::StreamableHttpService; #[cfg(all(feature = "transport-streamable-http-server", feature = "axum", not(feature = "actix-web")))] -pub use tower::StreamableHttpService; +pub use axum::StreamableHttpService; -pub use session::{SessionId, SessionManager}; +pub use session::{SessionId, SessionManager}; \ No newline at end of file diff --git a/crates/rmcp/src/transport/streamable_http_server/actix_impl.rs b/crates/rmcp/src/transport/streamable_http_server/actix_web.rs similarity index 100% rename from crates/rmcp/src/transport/streamable_http_server/actix_impl.rs rename to crates/rmcp/src/transport/streamable_http_server/actix_web.rs diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs similarity index 100% rename from crates/rmcp/src/transport/streamable_http_server/tower.rs rename to crates/rmcp/src/transport/streamable_http_server/axum.rs diff --git a/crates/rmcp/tests/test_sse_server.rs b/crates/rmcp/tests/test_sse_server.rs index 57bb20a1..20088d11 100644 --- a/crates/rmcp/tests/test_sse_server.rs +++ b/crates/rmcp/tests/test_sse_server.rs @@ -7,9 +7,9 @@ use rmcp::{ // Import framework-specific types #[cfg(feature = "axum")] -use rmcp::transport::AxumSseServer; +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; #[cfg(feature = "actix-web")] -use rmcp::transport::ActixSseServer; +use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 8387f2af..39a0e7ce 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -12,9 +12,9 @@ use rmcp::{ // Import framework-specific types #[cfg(feature = "axum")] -use rmcp::transport::AxumStreamableHttpService; +use rmcp::transport::streamable_http_server::axum::StreamableHttpService as AxumStreamableHttpService; #[cfg(feature = "actix-web")] -use rmcp::transport::ActixStreamableHttpService; +use rmcp::transport::streamable_http_server::actix_web::StreamableHttpService as ActixStreamableHttpService; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 44c2df14..c036c6cf 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -5,9 +5,9 @@ use rmcp::{ // Import framework-specific types #[cfg(feature = "axum")] -use rmcp::transport::AxumSseServer; +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; #[cfg(feature = "actix-web")] -use rmcp::transport::ActixSseServer; +use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; use tokio::time::timeout; use tokio_util::sync::CancellationToken; diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 85bcee11..8341b70a 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -96,4 +96,4 @@ path = "src/counter_streamable_http_actix.rs" required-features = ["actix-web"] [features] -actix-web = ["dep:actix-web", "dep:actix-rt", "dep:rmcp/actix-web"] +actix-web = ["dep:actix-web", "dep:actix-rt", "rmcp/actix-web"] From 4cba1da8185fd0a51062275807248464da007116 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 19:58:27 +0200 Subject: [PATCH 11/12] docs: add comprehensive documentation for actix-web feature - Document actix-web as an alternative web framework in README files - Add feature flag documentation explaining axum (default) vs actix-web - Include examples showing both framework usage patterns - Document precedence behavior when both features are enabled - Add explanatory comments to actix-web example files - Note runtime differences (#[actix_web::main] vs #[tokio::main]) --- README.md | 11 ++++ crates/rmcp/README.md | 59 ++++++++++++++++++- crates/rmcp/src/lib.rs | 17 ++++++ examples/servers/src/counter_sse_actix.rs | 6 ++ .../src/counter_streamable_http_actix.rs | 7 +++ 5 files changed, 97 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3cb91ceb..3c822c44 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,17 @@ rmcp = { version = "0.1", features = ["server"] } ## or dev channel rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main" } ``` + +#### Web Framework Choice +By default, rmcp uses [axum](https://github.com/tokio-rs/axum) for web server transports (SSE and streamable HTTP). You can also use [actix-web](https://github.com/actix/actix-web) as an alternative: + +```toml +# For actix-web support +rmcp = { version = "0.1", features = ["server", "actix-web"] } +``` + +**Note**: When both `axum` and `actix-web` features are enabled, actix-web takes precedence for convenience type aliases. + ### Third Dependencies Basic dependencies: - [tokio required](https://github.com/tokio-rs/tokio) diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 7c93dd24..38f2ce25 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -195,6 +195,9 @@ RMCP uses feature flags to control which components are included: - `client`: Enable client functionality - `server`: Enable server functionality and the tool system - `macros`: Enable the `#[tool]` macro (enabled by default) +- Web framework features: + - `axum`: Axum web framework support (enabled by default) + - `actix-web`: actix-web framework support as an alternative to axum - Transport-specific features: - `transport-async-rw`: Async read/write support - `transport-io`: I/O stream support @@ -204,15 +207,65 @@ RMCP uses feature flags to control which components are included: - `auth`: OAuth2 authentication support - `schemars`: JSON Schema generation (for tool definitions) +**Note**: When both `axum` and `actix-web` features are enabled, actix-web implementations take precedence for convenience type aliases. + + +## Web Framework Support + +SSE and streamable HTTP server transports support multiple web frameworks: + +### Using axum (default) +```rust +use rmcp::transport::SseServer; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; + let ct = server.with_service(|| Ok(MyService::new())); + ct.cancelled().await; + Ok(()) +} +``` + +### Using actix-web +Enable the `actix-web` feature in your `Cargo.toml`: +```toml +rmcp = { version = "0.1", features = ["server", "actix-web"] } +``` + +Then use with actix-web runtime: +```rust +use rmcp::transport::SseServer; + +#[actix_web::main] +async fn main() -> Result<(), Box> { + let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; + let ct = server.with_service(|| Ok(MyService::new())); + ct.cancelled().await; + Ok(()) +} +``` + +### Framework-specific imports +When you need to use a specific framework implementation: +```rust +// For axum +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer; + +// For actix-web +#[cfg(feature = "actix-web")] +use rmcp::transport::sse_server::actix_web::SseServer; +``` ## Transports - `transport-io`: Server stdio transport -- `transport-sse-server`: Server SSE transport +- `transport-sse-server`: Server SSE transport (supports both axum and actix-web) - `transport-child-process`: Client stdio transport - `transport-sse-client`: Client sse transport -- `transport-streamable-http-server` streamable http server transport -- `transport-streamable-http-client` streamable http client transport +- `transport-streamable-http-server`: Streamable HTTP server transport (supports both axum and actix-web) +- `transport-streamable-http-client`: Streamable HTTP client transport
Transport diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index a3d1a414..cbd62499 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -51,6 +51,23 @@ //! Next also implement [ServerHandler] for `Counter` and start the server inside //! `main` by calling `Counter::new().serve(...)`. See the examples directory in the repository for more information. //! +//! ### Web Framework Support +//! +//! Server transports (SSE and streamable HTTP) support both axum (default) and actix-web: +//! +//! ```rust,ignore +//! // Using actix-web (requires actix-web feature) +//! use rmcp::{ServiceExt, transport::SseServer}; +//! +//! #[actix_web::main] +//! async fn main() -> Result<(), Box> { +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! let ct = server.with_service(Counter::new); +//! ct.cancelled().await; +//! Ok(()) +//! } +//! ``` +//! //! ## Client //! //! A client can be used to interact with a server. Clients can be used to get a diff --git a/examples/servers/src/counter_sse_actix.rs b/examples/servers/src/counter_sse_actix.rs index 532e9cb1..f1c6af86 100644 --- a/examples/servers/src/counter_sse_actix.rs +++ b/examples/servers/src/counter_sse_actix.rs @@ -1,3 +1,5 @@ +// Example of using SSE server transport with actix-web framework +// This requires the "actix-web" feature to be enabled in Cargo.toml use rmcp::transport::sse_server::{SseServer, SseServerConfig}; use tracing_subscriber::{ layer::SubscriberExt, @@ -9,6 +11,8 @@ use common::counter::Counter; const BIND_ADDRESS: &str = "127.0.0.1:8000"; +// Note: Using #[actix_web::main] instead of #[tokio::main] +// This sets up the actix-web runtime which is required for actix-web transports #[actix_web::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::registry() @@ -29,6 +33,8 @@ async fn main() -> anyhow::Result<()> { let ct_signal = config.ct.clone(); + // When actix-web feature is enabled, SseServer uses actix-web implementation + // The same API works with both axum and actix-web let sse_server = SseServer::serve_with_config(config).await?; let bind_addr = sse_server.config.bind; let ct = sse_server.with_service(Counter::new); diff --git a/examples/servers/src/counter_streamable_http_actix.rs b/examples/servers/src/counter_streamable_http_actix.rs index afbfe94f..1fa57d9a 100644 --- a/examples/servers/src/counter_streamable_http_actix.rs +++ b/examples/servers/src/counter_streamable_http_actix.rs @@ -1,3 +1,5 @@ +// Example of using streamable HTTP server transport with actix-web framework +// This requires the "actix-web" feature to be enabled in Cargo.toml mod common; use std::sync::Arc; @@ -7,6 +9,8 @@ use rmcp::transport::streamable_http_server::{ StreamableHttpService, session::local::LocalSessionManager, }; +// Note: Using #[actix_web::main] instead of #[tokio::main] +// This sets up the actix-web runtime which is required for actix-web transports #[actix_web::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() @@ -17,6 +21,7 @@ async fn main() -> anyhow::Result<()> { let bind_addr = "127.0.0.1:8080"; // Create the streamable HTTP service + // When actix-web feature is enabled, StreamableHttpService uses actix-web implementation let service = Arc::new(StreamableHttpService::new( || Ok(Counter::new()), LocalSessionManager::default().into(), @@ -28,6 +33,8 @@ async fn main() -> anyhow::Result<()> { println!("GET / - Resume SSE stream with session ID"); println!("DELETE / - Close session"); + // Use actix-web's HttpServer and App to host the service + // The StreamableHttpService::configure method sets up the routes HttpServer::new(move || { App::new() .wrap(middleware::Logger::default()) From 9f02c892e3fdddaf055da65e083e9558f072e175 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Tue, 1 Jul 2025 20:43:28 +0200 Subject: [PATCH 12/12] style: apply cargo fmt Apply Rust formatting standards to ensure consistent code style across the codebase. --- crates/rmcp/src/transport.rs | 55 +---- .../src/transport/sse_server/actix_web.rs | 192 +++++++++--------- crates/rmcp/src/transport/sse_server/axum.rs | 120 ++++++----- .../rmcp/src/transport/sse_server/common.rs | 3 +- crates/rmcp/src/transport/sse_server/mod.rs | 10 +- .../src/transport/streamable_http_server.rs | 20 +- .../streamable_http_server/actix_web.rs | 168 ++++++++------- crates/rmcp/tests/common/calculator.rs | 2 +- crates/rmcp/tests/test_sse_server.rs | 117 ++++++----- crates/rmcp/tests/test_with_js.rs | 51 ++--- crates/rmcp/tests/test_with_python.rs | 19 +- examples/servers/src/complex_auth_sse.rs | 3 +- .../src/counter_hyper_streamable_http.rs | 3 +- examples/servers/src/counter_sse.rs | 37 ++-- examples/servers/src/counter_sse_actix.rs | 9 +- .../src/counter_streamable_http_actix.rs | 4 +- examples/servers/src/counter_streamhttp.rs | 3 +- examples/servers/src/simple_auth_sse.rs | 2 +- 18 files changed, 412 insertions(+), 406 deletions(-) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 6e2f4bd4..2cf94012 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -10,37 +10,7 @@ //! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] | //! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] | //! -//! ## Framework Support -//! -//! Several transport types support multiple web frameworks through feature flags: -//! -//! ### SSE Server Transport -//! - **Convenience alias**: [`SseServer`] - resolves to the appropriate implementation based on enabled features -//! - **Framework-specific**: Available via `sse_server::axum::SseServer` and `sse_server::actix_web::SseServer` -//! -//! ### Streamable HTTP Server Transport -//! - **Convenience alias**: [`StreamableHttpService`] - resolves to the appropriate implementation based on enabled features -//! - **Framework-specific**: Available via `streamable_http_server::axum::StreamableHttpService` and `streamable_http_server::actix_web::StreamableHttpService` -//! -//! #### Type Resolution Strategy -//! The convenience aliases resolve as follows: -//! - When `actix-web` feature is enabled: aliases point to actix-web implementations -//! - When only `axum` feature is enabled: aliases point to axum implementations -//! -//! #### Usage Examples -//! ```rust,ignore -//! // Using convenience aliases (recommended for most cases) -//! use rmcp::transport::{SseServer, StreamableHttpService}; -//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; -//! -//! // Using framework-specific modules (when you need a specific implementation) -//! #[cfg(feature = "axum")] -//! use rmcp::transport::sse_server::axum::SseServer; -//! #[cfg(feature = "axum")] -//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; -//! ``` -//! -//! ## Helper Transport Types +//!## Helper Transport Types //! Thers are several helper transport types that can help you to create transport quickly. //! //! ### [Worker Transport](`worker::WorkerTransport`) @@ -135,16 +105,10 @@ pub use sse_client::SseClientTransport; #[cfg(feature = "transport-sse-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] pub mod sse_server; - -// Re-export convenience alias -#[cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-sse-server", any(feature = "axum", feature = "actix-web")))))] +#[cfg(feature = "transport-sse-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-server")))] pub use sse_server::SseServer; -// Framework-specific implementations are available via submodules: -// - sse_server::axum::SseServer -// - sse_server::actix_web::SseServer - #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] pub mod auth; @@ -158,20 +122,9 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized #[cfg(feature = "transport-streamable-http-server-session")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server-session")))] pub mod streamable_http_server; - -// Re-export configuration #[cfg(feature = "transport-streamable-http-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use streamable_http_server::StreamableHttpServerConfig; - -// Re-export the preferred implementation -#[cfg(all(feature = "transport-streamable-http-server", any(feature = "axum", feature = "actix-web")))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", any(feature = "axum", feature = "actix-web")))))] -pub use streamable_http_server::StreamableHttpService; - -// Framework-specific implementations are available via submodules: -// - streamable_http_server::axum::StreamableHttpService -// - streamable_http_server::actix_web::StreamableHttpService +pub use streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService}; #[cfg(feature = "transport-streamable-http-client")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] diff --git a/crates/rmcp/src/transport/sse_server/actix_web.rs b/crates/rmcp/src/transport/sse_server/actix_web.rs index 8ca4fe00..6fbab232 100644 --- a/crates/rmcp/src/transport/sse_server/actix_web.rs +++ b/crates/rmcp/src/transport/sse_server/actix_web.rs @@ -12,15 +12,14 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::{CancellationToken, PollSender}; use tracing::Instrument; +use super::common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, + transport::common::http_header::HEADER_X_ACCEL_BUFFERING, }; -use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; -use crate::transport::common::http_header::HEADER_X_ACCEL_BUFFERING; - type TxStore = Arc>>>; @@ -67,41 +66,39 @@ async fn post_event_handler( ) -> Result { let session_id = &query.session_id; tracing::debug!(session_id, ?message, "new client message"); - + let tx = { let rg = app_data.txs.read().await; rg.get(session_id.as_str()) .ok_or_else(|| actix_web::error::ErrorNotFound("Session not found"))? .clone() }; - + // Note: In actix-web, we don't have direct access to modify extensions // This would need a different approach for passing HTTP request context - + if tx.send(message.0).await.is_err() { tracing::error!("send message error"); return Err(actix_web::error::ErrorGone("Session closed")); } - + Ok(HttpResponse::Accepted().finish()) } -async fn sse_handler( - app_data: Data, - _req: HttpRequest, -) -> Result { +async fn sse_handler(app_data: Data, _req: HttpRequest) -> Result { let session = session_id(); tracing::info!(%session, "sse connection"); - + let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64); let to_client_tx_clone = to_client_tx.clone(); - app_data.txs + app_data + .txs .write() .await .insert(session.clone(), from_client_tx); - + let _session_id = session.clone(); let stream = ReceiverStream::new(from_client_rx); let sink = PollSender::new(to_client_tx); @@ -111,17 +108,19 @@ async fn sse_handler( session_id: session.clone(), tx_store: app_data.txs.clone(), }; - + let transport_send_result = app_data.transport_tx.send(transport); if transport_send_result.is_err() { tracing::warn!("send transport out error"); - return Err(ErrorInternalServerError("Failed to send transport, server is closed")); + return Err(ErrorInternalServerError( + "Failed to send transport, server is closed", + )); } - + let post_path = app_data.post_path.clone(); let ping_interval = app_data.sse_ping_interval; let session_for_stream = session.clone(); - + // Create SSE response stream let sse_stream = async_stream::stream! { // Send initial endpoint message @@ -129,13 +128,13 @@ async fn sse_handler( "event: endpoint\ndata: {}?sessionId={}\n\n", post_path, session_for_stream ))); - + // Set up ping interval let mut ping_interval = tokio::time::interval(ping_interval); ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - + let mut rx = ReceiverStream::new(to_client_rx); - + loop { tokio::select! { Some(message) = rx.next() => { @@ -155,18 +154,18 @@ async fn sse_handler( } } }; - + // Clean up on disconnect let app_data_clone = app_data.clone(); let session_for_cleanup = session.clone(); actix_rt::spawn(async move { to_client_tx_clone.closed().await; - + let mut txs = app_data_clone.txs.write().await; txs.remove(&session_for_cleanup); tracing::debug!(%session_for_cleanup, "Closed session and cleaned up resources"); }); - + Ok(HttpResponse::Ok() .content_type("text/event-stream") .insert_header(("Cache-Control", "no-cache")) @@ -259,23 +258,23 @@ impl SseServer { }) .await } - + pub async fn serve_with_config(mut config: SseServerConfig) -> io::Result { let bind_addr = config.bind; let ct = config.ct.clone(); - + // First bind to get the actual address let listener = std::net::TcpListener::bind(bind_addr)?; let actual_addr = listener.local_addr()?; listener.set_nonblocking(true)?; - + // Update config with actual address config.bind = actual_addr; let (sse_server, _) = Self::new(config); let app_data = sse_server.app_data.clone(); let sse_path = sse_server.config.sse_path.clone(); let post_path = sse_server.config.post_path.clone(); - + let server = actix_web::HttpServer::new(move || { actix_web::App::new() .app_data(app_data.clone()) @@ -285,16 +284,16 @@ impl SseServer { }) .listen(listener)? .run(); - + let ct_child = ct.child_token(); let server_handle = server.handle(); - + actix_rt::spawn(async move { ct_child.cancelled().await; tracing::info!("sse server cancelled"); server_handle.stop(true).await; }); - + actix_rt::spawn( async move { if let Err(e) = server.await { @@ -303,7 +302,7 @@ impl SseServer { } .instrument(tracing::info_span!("sse-server", bind_address = %actual_addr)), ); - + Ok(sse_server) } @@ -312,17 +311,17 @@ impl SseServer { config.post_path.clone(), config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), ); - + let sse_path = config.sse_path.clone(); let post_path = config.post_path.clone(); - + let app_data = Data::new(app_data); - + let scope = web::scope("") .app_data(app_data.clone()) .route(&sse_path, web::get().to(sse_handler)) .route(&post_path, web::post().to(post_event_handler)); - + let server = SseServer { transport_rx: Arc::new(Mutex::new(transport_rx)), config, @@ -340,7 +339,7 @@ impl SseServer { use crate::service::ServiceExt; let ct = self.config.ct.clone(); let transport_rx = self.transport_rx.clone(); - + actix_rt::spawn(async move { while let Some(transport) = transport_rx.lock().await.recv().await { let service = service_provider(); @@ -366,7 +365,7 @@ impl SseServer { { let ct = self.config.ct.clone(); let transport_rx = self.transport_rx.clone(); - + actix_rt::spawn(async move { while let Some(transport) = transport_rx.lock().await.recv().await { let service = service_provider(); @@ -410,30 +409,32 @@ impl Stream for SseServer { #[cfg(test)] mod tests { - use super::*; use futures::{SinkExt, StreamExt}; use tokio::time::timeout; + use super::*; + #[tokio::test] async fn test_session_management() { - let (app_data, transport_rx) = AppData::new("/message".to_string(), Duration::from_secs(15)); - + let (app_data, transport_rx) = + AppData::new("/message".to_string(), Duration::from_secs(15)); + // Create a session let session_id = session_id(); let (tx, _rx) = tokio::sync::mpsc::channel(64); - + // Insert session app_data.txs.write().await.insert(session_id.clone(), tx); - + // Verify session exists assert!(app_data.txs.read().await.contains_key(&session_id)); - + // Remove session app_data.txs.write().await.remove(&session_id); - + // Verify session removed assert!(!app_data.txs.read().await.contains_key(&session_id)); - + drop(transport_rx); } @@ -446,12 +447,12 @@ mod tests { ct: CancellationToken::new(), sse_keep_alive: Some(Duration::from_secs(15)), }; - + let (sse_server, scope) = SseServer::new(config); - + assert_eq!(sse_server.config.sse_path, "/sse"); assert_eq!(sse_server.config.post_path, "/message"); - + // Scope should be properly configured drop(scope); // Just ensure it's created without panic } @@ -462,59 +463,61 @@ mod tests { let stream = ReceiverStream::new(rx); let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); let sink = PollSender::new(sink_tx); - + let mut transport = SseServerTransport { stream, sink, session_id: session_id(), tx_store: Default::default(), }; - + // Test sending through transport - use crate::model::{ServerResult, EmptyResult, JsonRpcMessage}; - let msg: TxJsonRpcMessage = JsonRpcMessage::Response(crate::model::JsonRpcResponse { - jsonrpc: crate::model::JsonRpcVersion2_0, - id: crate::model::NumberOrString::Number(1), - result: ServerResult::EmptyResult(EmptyResult {}), - }); + use crate::model::{EmptyResult, JsonRpcMessage, ServerResult}; + let msg: TxJsonRpcMessage = + JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); // For PollSender, we need to send through async context transport.send(msg).await.unwrap(); - + // Should receive the message let received = timeout(Duration::from_millis(100), sink_rx.recv()) .await .unwrap() .unwrap(); - + match received { - TxJsonRpcMessage::::Response(_) => {}, + TxJsonRpcMessage::::Response(_) => {} _ => panic!("Unexpected message type"), } - + // Test receiving through transport - let client_msg: RxJsonRpcMessage = crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { - jsonrpc: crate::model::JsonRpcVersion2_0, - notification: crate::model::ClientNotification::CancelledNotification( - crate::model::Notification { - method: crate::model::CancelledNotificationMethod, - params: crate::model::CancelledNotificationParam { - request_id: crate::model::NumberOrString::Number(1), - reason: None, + let client_msg: RxJsonRpcMessage = + crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), }, - extensions: Default::default(), - } - ), - }); + ), + }); tx.send(client_msg).await.unwrap(); drop(tx); - + let received = timeout(Duration::from_millis(100), transport.next()) .await .unwrap() .unwrap(); - + match received { - RxJsonRpcMessage::::Notification(_) => {}, + RxJsonRpcMessage::::Notification(_) => {} _ => panic!("Unexpected message type"), } } @@ -522,14 +525,14 @@ mod tests { #[actix_web::test] async fn test_post_event_handler_session_not_found() { use actix_web::test; - + let (app_data, _) = AppData::new("/message".to_string(), Duration::from_secs(15)); let app_data = Data::new(app_data); - + let query = PostEventQuery { session_id: "non-existent".to_string(), }; - + // Create a simple cancelled notification let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { jsonrpc: crate::model::JsonRpcVersion2_0, @@ -541,17 +544,18 @@ mod tests { reason: None, }, extensions: Default::default(), - } + }, ), }); - + let result = post_event_handler( app_data, Query(query), test::TestRequest::default().to_http_request(), Json(client_msg), - ).await; - + ) + .await; + assert!(result.is_err()); } @@ -564,42 +568,44 @@ mod tests { ct: CancellationToken::new(), sse_keep_alive: None, }; - + let ct = config.ct.clone(); let (sse_server, _) = SseServer::new(config); - + // Test that the cancellation token is properly connected assert!(!ct.is_cancelled()); ct.cancel(); assert!(ct.is_cancelled()); - + // Verify server config assert!(sse_server.config.ct.is_cancelled()); } #[actix_web::test] async fn test_sse_stream_generation() { - let (app_data, mut transport_rx) = AppData::new("/message".to_string(), Duration::from_secs(15)); + let (app_data, mut transport_rx) = + AppData::new("/message".to_string(), Duration::from_secs(15)); let app_data = Data::new(app_data); - + // Call SSE handler let result = sse_handler( app_data.clone(), actix_web::test::TestRequest::default().to_http_request(), - ).await; - + ) + .await; + assert!(result.is_ok()); let response = result.unwrap(); - + // Check response headers assert_eq!(response.status(), actix_web::http::StatusCode::OK); assert_eq!( response.headers().get("content-type").unwrap(), "text/event-stream" ); - + // Verify a transport was created let transport = transport_rx.try_recv(); assert!(transport.is_ok()); } -} \ No newline at end of file +} diff --git a/crates/rmcp/src/transport/sse_server/axum.rs b/crates/rmcp/src/transport/sse_server/axum.rs index 974e96d7..f24a9fd5 100644 --- a/crates/rmcp/src/transport/sse_server/axum.rs +++ b/crates/rmcp/src/transport/sse_server/axum.rs @@ -15,14 +15,13 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::{CancellationToken, PollSender}; use tracing::Instrument; +use super::common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, }; -use super::common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; - type TxStore = Arc>>>; @@ -214,7 +213,6 @@ impl Stream for SseServerTransport { } } - #[derive(Debug)] pub struct SseServer { transport_rx: tokio::sync::mpsc::UnboundedReceiver, @@ -338,30 +336,31 @@ impl Stream for SseServer { #[cfg(test)] mod tests { - use super::*; use futures::{SinkExt, StreamExt}; use tokio::time::timeout; + use super::*; + #[tokio::test] async fn test_session_management() { let (app, transport_rx) = App::new("/message".to_string(), Duration::from_secs(15)); - + // Create a session let session_id = session_id(); let (tx, _rx) = tokio::sync::mpsc::channel(64); - + // Insert session app.txs.write().await.insert(session_id.clone(), tx); - + // Verify session exists assert!(app.txs.read().await.contains_key(&session_id)); - + // Remove session app.txs.write().await.remove(&session_id); - + // Verify session removed assert!(!app.txs.read().await.contains_key(&session_id)); - + drop(transport_rx); } @@ -374,12 +373,12 @@ mod tests { ct: CancellationToken::new(), sse_keep_alive: Some(Duration::from_secs(15)), }; - + let (sse_server, router) = SseServer::new(config); - + assert_eq!(sse_server.config.sse_path, "/sse"); assert_eq!(sse_server.config.post_path, "/message"); - + // Router should be properly configured drop(router); // Just ensure it's created without panic } @@ -390,75 +389,79 @@ mod tests { let stream = ReceiverStream::new(rx); let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); let sink = PollSender::new(sink_tx); - + let mut transport = SseServerTransport { stream, sink, session_id: session_id(), tx_store: Default::default(), }; - + // Test sending through transport - use crate::model::{ServerResult, EmptyResult, JsonRpcMessage}; - let msg: TxJsonRpcMessage = JsonRpcMessage::Response(crate::model::JsonRpcResponse { - jsonrpc: crate::model::JsonRpcVersion2_0, - id: crate::model::NumberOrString::Number(1), - result: ServerResult::EmptyResult(EmptyResult {}), - }); + use crate::model::{EmptyResult, JsonRpcMessage, ServerResult}; + let msg: TxJsonRpcMessage = + JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); // For PollSender, we need to send through async context transport.send(msg).await.unwrap(); - + // Should receive the message let received = timeout(Duration::from_millis(100), sink_rx.recv()) .await .unwrap() .unwrap(); - + match received { - TxJsonRpcMessage::::Response(_) => {}, + TxJsonRpcMessage::::Response(_) => {} _ => panic!("Unexpected message type"), } - - // Test receiving through transport - let client_msg: RxJsonRpcMessage = crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { - jsonrpc: crate::model::JsonRpcVersion2_0, - notification: crate::model::ClientNotification::CancelledNotification( - crate::model::Notification { - method: crate::model::CancelledNotificationMethod, - params: crate::model::CancelledNotificationParam { - request_id: crate::model::NumberOrString::Number(1), - reason: None, + + // Test receiving through transport + let client_msg: RxJsonRpcMessage = + crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), }, - extensions: Default::default(), - } - ), - }); + ), + }); tx.send(client_msg).await.unwrap(); drop(tx); - + let received = timeout(Duration::from_millis(100), transport.next()) .await .unwrap() .unwrap(); - + match received { - RxJsonRpcMessage::::Notification(_) => {}, + RxJsonRpcMessage::::Notification(_) => {} _ => panic!("Unexpected message type"), } } #[tokio::test] async fn test_post_event_handler_session_not_found() { - use axum::extract::{Query, State}; - use axum::Json; - use axum::http::Request; - + use axum::{ + Json, + extract::{Query, State}, + http::Request, + }; + let (app, _) = App::new("/message".to_string(), Duration::from_secs(15)); - + let query = PostEventQuery { session_id: "non-existent".to_string(), }; - + // Create a minimal request parts let request = Request::builder() .method("POST") @@ -466,7 +469,7 @@ mod tests { .body(()) .unwrap(); let (parts, _) = request.into_parts(); - + // Create a simple cancelled notification let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { jsonrpc: crate::model::JsonRpcVersion2_0, @@ -478,21 +481,16 @@ mod tests { reason: None, }, extensions: Default::default(), - } + }, ), }); - - let result = post_event_handler( - State(app), - Query(query), - parts, - Json(client_msg), - ).await; - + + let result = post_event_handler(State(app), Query(query), parts, Json(client_msg)).await; + assert_eq!(result, Err(StatusCode::NOT_FOUND)); } - #[tokio::test] + #[tokio::test] async fn test_server_with_cancellation() { let config = SseServerConfig { bind: "127.0.0.1:0".parse().unwrap(), @@ -501,13 +499,13 @@ mod tests { ct: CancellationToken::new(), sse_keep_alive: None, }; - + let ct_clone = config.ct.clone(); let (mut sse_server, _) = SseServer::new(config); - + // Cancel immediately ct_clone.cancel(); - + // next_transport should return None after cancellation let transport = timeout(Duration::from_millis(100), sse_server.next_transport()).await; assert!(transport.is_ok()); diff --git a/crates/rmcp/src/transport/sse_server/common.rs b/crates/rmcp/src/transport/sse_server/common.rs index 6cfc4f3e..bedfe1c4 100644 --- a/crates/rmcp/src/transport/sse_server/common.rs +++ b/crates/rmcp/src/transport/sse_server/common.rs @@ -1,4 +1,5 @@ use std::{net::SocketAddr, sync::Arc, time::Duration}; + use tokio_util::sync::CancellationToken; pub type SessionId = Arc; @@ -16,4 +17,4 @@ pub struct SseServerConfig { pub sse_keep_alive: Option, } -pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); \ No newline at end of file +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); diff --git a/crates/rmcp/src/transport/sse_server/mod.rs b/crates/rmcp/src/transport/sse_server/mod.rs index 3a2e07aa..96f90031 100644 --- a/crates/rmcp/src/transport/sse_server/mod.rs +++ b/crates/rmcp/src/transport/sse_server/mod.rs @@ -44,10 +44,12 @@ pub mod actix_web; // Convenience type alias #[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] pub use actix_web::SseServer; - -#[cfg(all(feature = "transport-sse-server", feature = "axum", not(feature = "actix-web")))] +#[cfg(all( + feature = "transport-sse-server", + feature = "axum", + not(feature = "actix-web") +))] pub use axum::SseServer; - // Re-export common types when transport-sse-server is enabled #[cfg(feature = "transport-sse-server")] -pub use common::{SseServerConfig, SessionId, session_id, DEFAULT_AUTO_PING_INTERVAL}; \ No newline at end of file +pub use common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index c6046e7f..239c4d82 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -54,19 +54,27 @@ impl Default for StreamableHttpServerConfig { // Axum implementation #[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))) +)] pub mod axum; // Actix-web implementation #[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))) +)] pub mod actix_web; // Export the preferred implementation as StreamableHttpService (without generic parameters) #[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] pub use actix_web::StreamableHttpService; - -#[cfg(all(feature = "transport-streamable-http-server", feature = "axum", not(feature = "actix-web")))] +#[cfg(all( + feature = "transport-streamable-http-server", + feature = "axum", + not(feature = "actix-web") +))] pub use axum::StreamableHttpService; - -pub use session::{SessionId, SessionManager}; \ No newline at end of file +pub use session::{SessionId, SessionManager}; diff --git a/crates/rmcp/src/transport/streamable_http_server/actix_web.rs b/crates/rmcp/src/transport/streamable_http_server/actix_web.rs index 61ad1b25..6f7c36d0 100644 --- a/crates/rmcp/src/transport/streamable_http_server/actix_web.rs +++ b/crates/rmcp/src/transport/streamable_http_server/actix_web.rs @@ -1,7 +1,11 @@ use std::sync::Arc; use actix_web::{ - HttpRequest, HttpResponse, Result, error::InternalError, http::{StatusCode, header}, middleware, web::{self, Bytes, Data}, + HttpRequest, HttpResponse, Result, + error::InternalError, + http::{StatusCode, header}, + middleware, + web::{self, Bytes, Data}, }; use futures::{Stream, StreamExt}; use tokio_stream::wrappers::ReceiverStream; @@ -14,10 +18,9 @@ use crate::{ service::serve_directly, transport::{ OneshotTransport, TransportAdapterIdentity, - common::{ - http_header::{ - EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, HEADER_X_ACCEL_BUFFERING, JSON_MIME_TYPE, - }, + common::http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, + HEADER_X_ACCEL_BUFFERING, JSON_MIME_TYPE, }, }, }; @@ -59,7 +62,7 @@ where .wrap(middleware::NormalizePath::trim()) .route("", web::get().to(Self::handle_get)) .route("", web::post().to(Self::handle_post)) - .route("", web::delete().to(Self::handle_delete)) + .route("", web::delete().to(Self::handle_delete)), ); } } @@ -73,7 +76,7 @@ where .headers() .get(header::ACCEPT) .and_then(|h| h.to_str().ok()); - + if !accept.is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) { return Ok(HttpResponse::NotAcceptable() .body("Not Acceptable: Client must accept text/event-stream")); @@ -85,12 +88,11 @@ where .get(HEADER_SESSION_ID) .and_then(|v| v.to_str().ok()) .map(|s| s.to_owned().into()); - + let Some(session_id) = session_id else { - return Ok(HttpResponse::Unauthorized() - .body("Unauthorized: Session ID is required")); + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required")); }; - + tracing::debug!(%session_id, "GET request for SSE stream"); // Check if session exists @@ -99,10 +101,9 @@ where .has_session(&session_id) .await .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + if !has_session { - return Ok(HttpResponse::Unauthorized() - .body("Unauthorized: Session not found")); + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found")); } // Check if last event id is provided @@ -113,28 +114,33 @@ where .map(|s| s.to_owned()); // Get the appropriate stream - let sse_stream: std::pin::Pin + Send>> = if let Some(last_event_id) = last_event_id { - tracing::debug!(%session_id, %last_event_id, "Resuming stream from last event"); - Box::pin(service - .session_manager - .resume(&session_id, last_event_id) - .await - .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?) - } else { - tracing::debug!(%session_id, "Creating standalone stream"); - Box::pin(service - .session_manager - .create_standalone_stream(&session_id) - .await - .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?) - }; + let sse_stream: std::pin::Pin + Send>> = + if let Some(last_event_id) = last_event_id { + tracing::debug!(%session_id, %last_event_id, "Resuming stream from last event"); + Box::pin( + service + .session_manager + .resume(&session_id, last_event_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?, + ) + } else { + tracing::debug!(%session_id, "Creating standalone stream"); + Box::pin( + service + .session_manager + .create_standalone_stream(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?, + ) + }; // Convert to SSE format let keep_alive = service.config.sse_keep_alive; let sse_stream = async_stream::stream! { let mut stream = sse_stream; let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); - + loop { tokio::select! { Some(msg) = stream.next() => { @@ -181,12 +187,13 @@ where .headers() .get(header::ACCEPT) .and_then(|h| h.to_str().ok()); - + if !accept.is_some_and(|header| { header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE) }) { - return Ok(HttpResponse::NotAcceptable() - .body("Not Acceptable: Client must accept both application/json and text/event-stream")); + return Ok(HttpResponse::NotAcceptable().body( + "Not Acceptable: Client must accept both application/json and text/event-stream", + )); } // Check content type @@ -194,7 +201,7 @@ where .headers() .get(header::CONTENT_TYPE) .and_then(|h| h.to_str().ok()); - + if !content_type.is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) { return Ok(HttpResponse::UnsupportedMediaType() .body("Unsupported Media Type: Content-Type must be application/json")); @@ -203,7 +210,7 @@ where // Deserialize the message let mut message: ClientJsonRpcMessage = serde_json::from_slice(&body) .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))?; - + tracing::debug!(?message, "POST request with message"); if service.config.stateful_mode { @@ -212,21 +219,20 @@ where .headers() .get(HEADER_SESSION_ID) .and_then(|v| v.to_str().ok()); - + if let Some(session_id) = session_id { let session_id = session_id.to_owned().into(); tracing::debug!(%session_id, "POST request with existing session"); - + let has_session = service .session_manager .has_session(&session_id) .await .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + if !has_session { tracing::warn!(%session_id, "Session not found"); - return Ok(HttpResponse::Unauthorized() - .body("Unauthorized: Session not found")); + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found")); } // Note: In actix-web we can't inject request parts like in tower, @@ -238,14 +244,16 @@ where .session_manager .create_stream(&session_id, message) .await - .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + .map_err(|e| { + InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR) + })?; + // Convert to SSE format let keep_alive = service.config.sse_keep_alive; let sse_stream = async_stream::stream! { let mut stream = Box::pin(stream); let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); - + loop { tokio::select! { Some(msg) = stream.next() => { @@ -289,11 +297,14 @@ where .session_manager .accept_message(&session_id, message) .await - .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + .map_err(|e| { + InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR) + })?; + Ok(HttpResponse::Accepted().finish()) } - ClientJsonRpcMessage::BatchRequest(_) | ClientJsonRpcMessage::BatchResponse(_) => { + ClientJsonRpcMessage::BatchRequest(_) + | ClientJsonRpcMessage::BatchResponse(_) => { Ok(HttpResponse::NotImplemented() .body("Batch requests are not supported yet")) } @@ -301,36 +312,39 @@ where } else { // No session id in stateful mode - create new session tracing::debug!("POST request without session, creating new session"); - + let (session_id, transport) = service .session_manager .create_session() .await .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + tracing::info!(%session_id, "Created new session"); - + if let ClientJsonRpcMessage::Request(req) = &mut message { if !matches!(req.request, ClientRequest::InitializeRequest(_)) { - return Ok(HttpResponse::UnprocessableEntity() - .body("Expected initialize request")); + return Ok( + HttpResponse::UnprocessableEntity().body("Expected initialize request") + ); } } else { - return Ok(HttpResponse::UnprocessableEntity() - .body("Expected initialize request")); + return Ok( + HttpResponse::UnprocessableEntity().body("Expected initialize request") + ); } - + let service_instance = service .get_service() .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + // Spawn a task to serve the session tokio::spawn({ let session_manager = service.session_manager.clone(); let session_id = session_id.clone(); async move { let service = serve_server::( - service_instance, transport, + service_instance, + transport, ) .await; match service { @@ -349,14 +363,14 @@ where }); } }); - + // Get initialize response let response = service .session_manager .initialize_session(&session_id, message) .await .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + // Return SSE stream with single response let sse_stream = async_stream::stream! { yield Ok::<_, actix_web::Error>(Bytes::from(format!( @@ -364,7 +378,7 @@ where serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string()) ))); }; - + Ok(HttpResponse::Ok() .content_type(EVENT_STREAM_MIME_TYPE) .append_header((header::CACHE_CONTROL, "no-cache")) @@ -375,39 +389,39 @@ where } else { // Stateless mode tracing::debug!("POST request in stateless mode"); - + match message { ClientJsonRpcMessage::Request(request) => { tracing::debug!(?request, "Processing request in stateless mode"); - + // In stateless mode, handle the request directly let service_instance = service .get_service() .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + let (transport, receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); let service_handle = serve_directly(service_instance, transport, None); - + tokio::spawn(async move { // Let the service process the request let _ = service_handle.waiting().await; }); - + // Convert receiver stream to SSE format let sse_stream = ReceiverStream::new(receiver).map(|message| { tracing::info!(?message); - let data = serde_json::to_string(&message) - .unwrap_or_else(|_| "{}".to_string()); + let data = + serde_json::to_string(&message).unwrap_or_else(|_| "{}".to_string()); Ok::<_, actix_web::Error>(Bytes::from(format!("data: {}\n\n", data))) }); - + // Add keep-alive if configured let keep_alive = service.config.sse_keep_alive; let sse_stream = async_stream::stream! { let mut stream = Box::pin(sse_stream); let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); - + loop { tokio::select! { Some(result) = stream.next() => { @@ -432,17 +446,14 @@ where } } }; - + Ok(HttpResponse::Ok() .content_type(EVENT_STREAM_MIME_TYPE) .append_header((header::CACHE_CONTROL, "no-cache")) .append_header((HEADER_X_ACCEL_BUFFERING, "no")) .streaming(sse_stream)) } - _ => { - Ok(HttpResponse::UnprocessableEntity() - .body("Unexpected message type")) - } + _ => Ok(HttpResponse::UnprocessableEntity().body("Unexpected message type")), } } } @@ -457,12 +468,11 @@ where .get(HEADER_SESSION_ID) .and_then(|v| v.to_str().ok()) .map(|s| s.to_owned().into()); - + let Some(session_id) = session_id else { - return Ok(HttpResponse::Unauthorized() - .body("Unauthorized: Session ID is required")); + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required")); }; - + tracing::debug!(%session_id, "DELETE request to close session"); // Close session @@ -471,9 +481,9 @@ where .close_session(&session_id) .await .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; - + tracing::info!(%session_id, "Session closed"); Ok(HttpResponse::NoContent().finish()) } -} \ No newline at end of file +} diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index 50936d02..f939654f 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -3,7 +3,7 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, tool_handler, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { diff --git a/crates/rmcp/tests/test_sse_server.rs b/crates/rmcp/tests/test_sse_server.rs index 20088d11..1f1c79bf 100644 --- a/crates/rmcp/tests/test_sse_server.rs +++ b/crates/rmcp/tests/test_sse_server.rs @@ -1,16 +1,11 @@ #![cfg(feature = "transport-sse-server")] -use rmcp::{ - ServiceExt, - transport::sse_server::SseServerConfig, -}; - // Import framework-specific types -#[cfg(feature = "axum")] -use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; #[cfg(feature = "actix-web")] use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; - +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; +use rmcp::{ServiceExt, transport::sse_server::SseServerConfig}; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -28,10 +23,14 @@ async fn init() { } // Common test logic for basic SSE server test -async fn test_sse_server_basic_common(bind_addr: std::net::SocketAddr, ct: CancellationToken, service_ct: CancellationToken) -> anyhow::Result<()> { +async fn test_sse_server_basic_common( + bind_addr: std::net::SocketAddr, + ct: CancellationToken, + service_ct: CancellationToken, +) -> anyhow::Result<()> { // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - + // Test that server is running by making a request let client = reqwest::Client::new(); let response = client @@ -39,10 +38,10 @@ async fn test_sse_server_basic_common(bind_addr: std::net::SocketAddr, ct: Cance .header("Accept", "text/event-stream") .send() .await?; - + // SSE endpoint should return OK and start streaming assert_eq!(response.status(), reqwest::StatusCode::OK); - + ct.cancel(); service_ct.cancel(); Ok(()) @@ -52,7 +51,7 @@ async fn test_sse_server_basic_common(bind_addr: std::net::SocketAddr, ct: Cance #[tokio::test] async fn test_axum_sse_server_basic() -> anyhow::Result<()> { init().await; - + let config = SseServerConfig { bind: "127.0.0.1:0".parse()?, sse_path: "/sse".to_string(), @@ -60,12 +59,12 @@ async fn test_axum_sse_server_basic() -> anyhow::Result<()> { ct: CancellationToken::new(), sse_keep_alive: None, }; - + let ct = config.ct.clone(); let sse_server = AxumSseServer::serve_with_config(config).await?; let bind_addr = sse_server.config.bind; let service_ct = sse_server.with_service(Calculator::default); - + test_sse_server_basic_common(bind_addr, ct, service_ct).await } @@ -73,7 +72,7 @@ async fn test_axum_sse_server_basic() -> anyhow::Result<()> { #[actix_web::test] async fn test_actix_sse_server_basic() -> anyhow::Result<()> { init().await; - + let config = SseServerConfig { bind: "127.0.0.1:0".parse()?, sse_path: "/sse".to_string(), @@ -81,126 +80,148 @@ async fn test_actix_sse_server_basic() -> anyhow::Result<()> { ct: CancellationToken::new(), sse_keep_alive: None, }; - + let ct = config.ct.clone(); let sse_server = ActixSseServer::serve_with_config(config).await?; let bind_addr = sse_server.config.bind; let service_ct = sse_server.with_service(Calculator::default); - + test_sse_server_basic_common(bind_addr, ct, service_ct).await } // Common client-server integration test logic #[cfg(feature = "transport-sse-client")] -async fn test_client_server_integration_common(actual_addr: std::net::SocketAddr, ct: CancellationToken) -> anyhow::Result<()> { +async fn test_client_server_integration_common( + actual_addr: std::net::SocketAddr, + ct: CancellationToken, +) -> anyhow::Result<()> { use rmcp::transport::SseClientTransport; - + let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; let client = ().serve(transport).await?; - + // Test basic operations let tools = client.list_all_tools().await?; assert!(!tools.is_empty()); assert_eq!(tools.len(), 2); // sum and sub - + client.cancel().await?; ct.cancel(); Ok(()) } -#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "axum"))] +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "axum" +))] #[tokio::test] async fn test_axum_client_server_integration() -> anyhow::Result<()> { init().await; - + const BIND_ADDRESS: &str = "127.0.0.1:0"; - + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); - + test_client_server_integration_common(actual_addr, ct).await } -#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "actix-web"))] +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "actix-web" +))] #[actix_web::test] async fn test_actix_client_server_integration() -> anyhow::Result<()> { init().await; - + const BIND_ADDRESS: &str = "127.0.0.1:0"; - + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); - + // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - + test_client_server_integration_common(actual_addr, ct).await } // Common concurrent clients test logic #[cfg(feature = "transport-sse-client")] -async fn test_concurrent_clients_common(actual_addr: std::net::SocketAddr, ct: CancellationToken) -> anyhow::Result<()> { +async fn test_concurrent_clients_common( + actual_addr: std::net::SocketAddr, + ct: CancellationToken, +) -> anyhow::Result<()> { use rmcp::transport::SseClientTransport; - + const NUM_CLIENTS: usize = 5; let mut handles = vec![]; - + for i in 0..NUM_CLIENTS { let addr = actual_addr; let handle = tokio::spawn(async move { let transport = SseClientTransport::start(format!("http://{}/sse", addr)).await?; let client = ().serve(transport).await?; - + // Each client does some operations let tools = client.list_all_tools().await?; assert!(!tools.is_empty()); assert_eq!(tools.len(), 2); // sum and sub - + tracing::info!("Client {} completed operations", i); client.cancel().await?; Ok::<(), anyhow::Error>(()) }); handles.push(handle); } - + // Wait for all clients to complete for handle in handles { handle.await??; } - + ct.cancel(); Ok(()) } -#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "axum"))] +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "axum" +))] #[tokio::test] async fn test_axum_concurrent_clients() -> anyhow::Result<()> { init().await; - + const BIND_ADDRESS: &str = "127.0.0.1:0"; - + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); - + test_concurrent_clients_common(actual_addr, ct).await } -#[cfg(all(feature = "transport-sse-server", feature = "transport-sse-client", feature = "actix-web"))] +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "actix-web" +))] #[actix_web::test] async fn test_actix_concurrent_clients() -> anyhow::Result<()> { init().await; - + const BIND_ADDRESS: &str = "127.0.0.1:0"; - + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; let actual_addr = sse_server.config.bind; let ct = sse_server.with_service(Calculator::default); - + // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - + test_concurrent_clients_common(actual_addr, ct).await -} \ No newline at end of file +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 39a0e7ce..6e22f9d7 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -1,20 +1,16 @@ +// Import framework-specific types +#[cfg(feature = "actix-web")] +use rmcp::transport::streamable_http_server::actix_web::StreamableHttpService as ActixStreamableHttpService; +#[cfg(feature = "axum")] +use rmcp::transport::streamable_http_server::axum::StreamableHttpService as AxumStreamableHttpService; use rmcp::{ ServiceExt, service::QuitReason, transport::{ ConfigureCommandExt, SseServer, StreamableHttpClientTransport, StreamableHttpServerConfig, - TokioChildProcess, - streamable_http_server::{ - session::local::LocalSessionManager, - }, + TokioChildProcess, streamable_http_server::session::local::LocalSessionManager, }, }; - -// Import framework-specific types -#[cfg(feature = "axum")] -use rmcp::transport::streamable_http_server::axum::StreamableHttpService as AxumStreamableHttpService; -#[cfg(feature = "actix-web")] -use rmcp::transport::streamable_http_server::actix_web::StreamableHttpService as ActixStreamableHttpService; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; @@ -151,40 +147,45 @@ async fn test_with_js_streamable_http_client_actix() -> anyhow::Result<()> { .wait() .await?; - let service = std::sync::Arc::new(ActixStreamableHttpService::::new( - || Ok(Calculator::new()), - Default::default(), - StreamableHttpServerConfig { - stateful_mode: true, - sse_keep_alive: None, - }, - )); - + let service = std::sync::Arc::new( + ActixStreamableHttpService::::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + }, + ), + ); + let server = actix_web::HttpServer::new(move || { actix_web::App::new() .wrap(actix_web::middleware::Logger::default()) .service( actix_web::web::scope("/mcp") - .configure(ActixStreamableHttpService::configure(service.clone())) + .configure(ActixStreamableHttpService::configure(service.clone())), ) }) .bind(STREAMABLE_HTTP_ACTIX_BIND_ADDRESS)? .run(); - + let server_handle = server.handle(); let server_task = tokio::spawn(server); - + // Give the server a moment to start tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - + let exit_status = tokio::process::Command::new("node") .arg("tests/test_with_js/streamable_client.js") - .arg(format!("http://{}/mcp/", STREAMABLE_HTTP_ACTIX_BIND_ADDRESS)) + .arg(format!( + "http://{}/mcp/", + STREAMABLE_HTTP_ACTIX_BIND_ADDRESS + )) .spawn()? .wait() .await?; assert!(exit_status.success()); - + server_handle.stop(true).await; let _ = server_task.await; Ok(()) diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index c036c6cf..58a84eb3 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,14 +1,12 @@ +// Import framework-specific types +#[cfg(feature = "actix-web")] +use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; use rmcp::{ ServiceExt, transport::{ConfigureCommandExt, TokioChildProcess, sse_server::SseServerConfig}, }; - -// Import framework-specific types -#[cfg(feature = "axum")] -use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; -#[cfg(feature = "actix-web")] -use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; - use tokio::time::timeout; use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -33,7 +31,10 @@ async fn init() -> anyhow::Result<()> { } // Common test logic for Python client -async fn test_with_python_client_common(bind_address: &str, ct: CancellationToken) -> anyhow::Result<()> { +async fn test_with_python_client_common( + bind_address: &str, + ct: CancellationToken, +) -> anyhow::Result<()> { let status = tokio::process::Command::new("uv") .arg("run") .arg("client.py") @@ -80,7 +81,7 @@ async fn test_with_python_client_actix() -> anyhow::Result<()> { #[tokio::test] async fn test_nested_with_python_client() -> anyhow::Result<()> { use axum::Router; - + init().await?; const BIND_ADDRESS: &str = "127.0.0.1:8001"; diff --git a/examples/servers/src/complex_auth_sse.rs b/examples/servers/src/complex_auth_sse.rs index 75488657..a713fda8 100644 --- a/examples/servers/src/complex_auth_sse.rs +++ b/examples/servers/src/complex_auth_sse.rs @@ -13,12 +13,11 @@ use axum::{ }; use rand::{Rng, distr::Alphanumeric}; use rmcp::transport::{ - SseServer, auth::{ AuthorizationMetadata, ClientRegistrationRequest, ClientRegistrationResponse, OAuthClientConfig, }, - sse_server::SseServerConfig, + sse_server::{SseServerConfig, axum::SseServer}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs index 6312180d..7c1a535a 100644 --- a/examples/servers/src/counter_hyper_streamable_http.rs +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -1,3 +1,4 @@ +// Example of using streamable HTTP server transport with hyper/tower mod common; use common::counter::Counter; use hyper_util::{ @@ -6,7 +7,7 @@ use hyper_util::{ service::TowerToHyperService, }; use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, + axum::StreamableHttpService, session::local::LocalSessionManager, }; #[tokio::main] diff --git a/examples/servers/src/counter_sse.rs b/examples/servers/src/counter_sse.rs index 373a8a6d..a62d6ec2 100644 --- a/examples/servers/src/counter_sse.rs +++ b/examples/servers/src/counter_sse.rs @@ -1,4 +1,6 @@ -use rmcp::transport::sse_server::{SseServer, SseServerConfig}; +// Example of using SSE server transport with axum framework +// This requires the "axum" feature to be enabled in Cargo.toml +use rmcp::transport::sse_server::{SseServerConfig, axum::SseServer}; use tracing_subscriber::{ layer::SubscriberExt, util::SubscriberInitExt, @@ -27,28 +29,27 @@ async fn main() -> anyhow::Result<()> { sse_keep_alive: None, }; - let (sse_server, router) = SseServer::new(config); + let ct_signal = config.ct.clone(); - // Do something with the router, e.g., add routes or middleware - - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; - - let ct = sse_server.config.ct.child_token(); + // When axum feature is enabled, use axum-specific SseServer + let sse_server = SseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let ct = sse_server.with_service(Counter::new); - let server = axum::serve(listener, router).with_graceful_shutdown(async move { - ct.cancelled().await; - tracing::info!("sse server cancelled"); - }); + println!("\nšŸš€ SSE Server (axum) running at http://{}", bind_addr); + println!("šŸ“” SSE endpoint: http://{}/sse", bind_addr); + println!("šŸ“® Message endpoint: http://{}/message", bind_addr); + println!("\nPress Ctrl+C to stop the server\n"); + // Set up Ctrl-C handler tokio::spawn(async move { - if let Err(e) = server.await { - tracing::error!(error = %e, "sse server shutdown with error"); - } + tokio::signal::ctrl_c().await.ok(); + println!("\nā¹ļø Shutting down..."); + ct_signal.cancel(); }); - let ct = sse_server.with_service(Counter::new); - - tokio::signal::ctrl_c().await?; - ct.cancel(); + // Wait for cancellation + ct.cancelled().await; + println!("āœ… Server stopped"); Ok(()) } diff --git a/examples/servers/src/counter_sse_actix.rs b/examples/servers/src/counter_sse_actix.rs index f1c6af86..622fc49f 100644 --- a/examples/servers/src/counter_sse_actix.rs +++ b/examples/servers/src/counter_sse_actix.rs @@ -32,14 +32,17 @@ async fn main() -> anyhow::Result<()> { }; let ct_signal = config.ct.clone(); - + // When actix-web feature is enabled, SseServer uses actix-web implementation // The same API works with both axum and actix-web let sse_server = SseServer::serve_with_config(config).await?; let bind_addr = sse_server.config.bind; let ct = sse_server.with_service(Counter::new); - println!("\nšŸš€ SSE Server (actix-web) running at http://{}", bind_addr); + println!( + "\nšŸš€ SSE Server (actix-web) running at http://{}", + bind_addr + ); println!("šŸ“” SSE endpoint: http://{}/sse", bind_addr); println!("šŸ“® Message endpoint: http://{}/message", bind_addr); println!("\nPress Ctrl+C to stop the server\n"); @@ -55,4 +58,4 @@ async fn main() -> anyhow::Result<()> { ct.cancelled().await; println!("āœ… Server stopped"); Ok(()) -} \ No newline at end of file +} diff --git a/examples/servers/src/counter_streamable_http_actix.rs b/examples/servers/src/counter_streamable_http_actix.rs index 1fa57d9a..fb7afa03 100644 --- a/examples/servers/src/counter_streamable_http_actix.rs +++ b/examples/servers/src/counter_streamable_http_actix.rs @@ -19,7 +19,7 @@ async fn main() -> anyhow::Result<()> { .init(); let bind_addr = "127.0.0.1:8080"; - + // Create the streamable HTTP service // When actix-web feature is enabled, StreamableHttpService uses actix-web implementation let service = Arc::new(StreamableHttpService::new( @@ -45,4 +45,4 @@ async fn main() -> anyhow::Result<()> { .await?; Ok(()) -} \ No newline at end of file +} diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index ff00cec6..b4d8bd65 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -1,5 +1,6 @@ +// Example of using streamable HTTP server transport with axum framework use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, + axum::StreamableHttpService, session::local::LocalSessionManager, }; use tracing_subscriber::{ layer::SubscriberExt, diff --git a/examples/servers/src/simple_auth_sse.rs b/examples/servers/src/simple_auth_sse.rs index 6514c161..2539a9a3 100644 --- a/examples/servers/src/simple_auth_sse.rs +++ b/examples/servers/src/simple_auth_sse.rs @@ -16,7 +16,7 @@ use axum::{ response::{Html, Response}, routing::get, }; -use rmcp::transport::{SseServer, sse_server::SseServerConfig}; +use rmcp::transport::sse_server::{SseServerConfig, axum::SseServer}; use tokio_util::sync::CancellationToken; mod common; use common::counter::Counter;