diff --git a/README.md b/README.md index 3cb91ceb..3c822c44 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,17 @@ rmcp = { version = "0.1", features = ["server"] } ## or dev channel rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main" } ``` + +#### Web Framework Choice +By default, rmcp uses [axum](https://github.com/tokio-rs/axum) for web server transports (SSE and streamable HTTP). You can also use [actix-web](https://github.com/actix/actix-web) as an alternative: + +```toml +# For actix-web support +rmcp = { version = "0.1", features = ["server", "actix-web"] } +``` + +**Note**: When both `axum` and `actix-web` features are enabled, actix-web takes precedence for convenience type aliases. + ### Third Dependencies Basic dependencies: - [tokio required](https://github.com/tokio-rs/tokio) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 2b9126d1..6420d107 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -55,12 +55,15 @@ process-wrap = { version = "8.2", features = ["tokio1"], optional = true } # for http-server transport axum = { version = "0.8", features = [], optional = true } +actix-web = { version = "4", optional = true } +actix-rt = { version = "2", optional = true } rand = { version = "0.9", optional = true } tokio-stream = { version = "0.1", optional = true } uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } +async-stream = { version = "0.3", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] @@ -70,10 +73,13 @@ chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", default-features = false, features = ["serde", "clock", "std", "oldtime"] } [features] -default = ["base64", 
"macros", "server"] +default = ["base64", "macros", "server", "axum"] client = ["dep:tokio-stream"] server = ["transport-async-rw", "dep:schemars"] macros = ["dep:rmcp-macros", "dep:paste"] +# Web framework features +axum = ["dep:axum"] +actix-web = ["dep:actix-web", "dep:actix-rt", "dep:async-stream"] # reqwest http client __reqwest = ["dep:reqwest"] @@ -116,7 +122,6 @@ transport-sse-server = [ "transport-async-rw", "transport-worker", "server-side-http", - "dep:axum", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", @@ -125,6 +130,7 @@ transport-streamable-http-server = [ transport-streamable-http-server-session = [ "transport-async-rw", "dep:tokio-stream", + "transport-worker", ] # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 7c93dd24..38f2ce25 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -195,6 +195,9 @@ RMCP uses feature flags to control which components are included: - `client`: Enable client functionality - `server`: Enable server functionality and the tool system - `macros`: Enable the `#[tool]` macro (enabled by default) +- Web framework features: + - `axum`: Axum web framework support (enabled by default) + - `actix-web`: actix-web framework support as an alternative to axum - Transport-specific features: - `transport-async-rw`: Async read/write support - `transport-io`: I/O stream support @@ -204,15 +207,65 @@ RMCP uses feature flags to control which components are included: - `auth`: OAuth2 authentication support - `schemars`: JSON Schema generation (for tool definitions) +**Note**: When both `axum` and `actix-web` features are enabled, actix-web implementations take precedence for convenience type aliases. 
+ + +## Web Framework Support + +SSE and streamable HTTP server transports support multiple web frameworks: + +### Using axum (default) +```rust +use rmcp::transport::SseServer; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; + let ct = server.with_service(|| Ok(MyService::new())); + ct.cancelled().await; + Ok(()) +} +``` + +### Using actix-web +Enable the `actix-web` feature in your `Cargo.toml`: +```toml +rmcp = { version = "0.1", features = ["server", "actix-web"] } +``` + +Then use with actix-web runtime: +```rust +use rmcp::transport::SseServer; + +#[actix_web::main] +async fn main() -> Result<(), Box> { + let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; + let ct = server.with_service(|| Ok(MyService::new())); + ct.cancelled().await; + Ok(()) +} +``` + +### Framework-specific imports +When you need to use a specific framework implementation: +```rust +// For axum +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer; + +// For actix-web +#[cfg(feature = "actix-web")] +use rmcp::transport::sse_server::actix_web::SseServer; +``` ## Transports - `transport-io`: Server stdio transport -- `transport-sse-server`: Server SSE transport +- `transport-sse-server`: Server SSE transport (supports both axum and actix-web) - `transport-child-process`: Client stdio transport - `transport-sse-client`: Client sse transport -- `transport-streamable-http-server` streamable http server transport -- `transport-streamable-http-client` streamable http client transport +- `transport-streamable-http-server`: Streamable HTTP server transport (supports both axum and actix-web) +- `transport-streamable-http-client`: Streamable HTTP client transport
Transport diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index a3d1a414..cbd62499 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -51,6 +51,23 @@ //! Next also implement [ServerHandler] for `Counter` and start the server inside //! `main` by calling `Counter::new().serve(...)`. See the examples directory in the repository for more information. //! +//! ### Web Framework Support +//! +//! Server transports (SSE and streamable HTTP) support both axum (default) and actix-web: +//! +//! ```rust,ignore +//! // Using actix-web (requires actix-web feature) +//! use rmcp::{ServiceExt, transport::SseServer}; +//! +//! #[actix_web::main] +//! async fn main() -> Result<(), Box> { +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! let ct = server.with_service(Counter::new); +//! ct.cancelled().await; +//! Ok(()) +//! } +//! ``` +//! //! ## Client //! //! A client can be used to interact with a server. Clients can be used to get a diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 20e6ce75..2cf94012 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -124,7 +124,7 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized pub mod streamable_http_server; #[cfg(feature = "transport-streamable-http-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService}; +pub use streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService}; #[cfg(feature = "transport-streamable-http-client")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 84bc7bfb..275c9f18 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -1,4 +1,5 
@@ pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +pub const HEADER_X_ACCEL_BUFFERING: &str = "X-Accel-Buffering"; pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; pub const JSON_MIME_TYPE: &str = "application/json"; diff --git a/crates/rmcp/src/transport/sse_server/actix_web.rs b/crates/rmcp/src/transport/sse_server/actix_web.rs new file mode 100644 index 00000000..6fbab232 --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/actix_web.rs @@ -0,0 +1,611 @@ +use std::{collections::HashMap, io, sync::Arc, time::Duration}; + +use actix_web::{ + HttpRequest, HttpResponse, Result, Scope, + error::ErrorInternalServerError, + middleware, + web::{self, Bytes, Data, Json, Query}, +}; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::{CancellationToken, PollSender}; +use tracing::Instrument; + +use super::common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; +use crate::{ + RoleServer, Service, + model::ClientJsonRpcMessage, + service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, + transport::common::http_header::HEADER_X_ACCEL_BUFFERING, +}; + +type TxStore = + Arc>>>; + +#[derive(Clone, Debug)] +struct AppData { + txs: TxStore, + transport_tx: tokio::sync::mpsc::UnboundedSender, + post_path: Arc, + sse_ping_interval: Duration, +} + +impl AppData { + pub fn new( + post_path: String, + sse_ping_interval: Duration, + ) -> ( + Self, + tokio::sync::mpsc::UnboundedReceiver, + ) { + let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); + ( + Self { + txs: Default::default(), + transport_tx, + post_path: post_path.into(), + sse_ping_interval, + }, + transport_rx, + ) + } +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + app_data: 
Data, + query: Query, + _req: HttpRequest, + message: Json, +) -> Result { + let session_id = &query.session_id; + tracing::debug!(session_id, ?message, "new client message"); + + let tx = { + let rg = app_data.txs.read().await; + rg.get(session_id.as_str()) + .ok_or_else(|| actix_web::error::ErrorNotFound("Session not found"))? + .clone() + }; + + // Note: In actix-web, we don't have direct access to modify extensions + // This would need a different approach for passing HTTP request context + + if tx.send(message.0).await.is_err() { + tracing::error!("send message error"); + return Err(actix_web::error::ErrorGone("Session closed")); + } + + Ok(HttpResponse::Accepted().finish()) +} + +async fn sse_handler(app_data: Data, _req: HttpRequest) -> Result { + let session = session_id(); + tracing::info!(%session, "sse connection"); + + let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); + let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64); + let to_client_tx_clone = to_client_tx.clone(); + + app_data + .txs + .write() + .await + .insert(session.clone(), from_client_tx); + + let _session_id = session.clone(); + let stream = ReceiverStream::new(from_client_rx); + let sink = PollSender::new(to_client_tx); + let transport = SseServerTransport { + stream, + sink, + session_id: session.clone(), + tx_store: app_data.txs.clone(), + }; + + let transport_send_result = app_data.transport_tx.send(transport); + if transport_send_result.is_err() { + tracing::warn!("send transport out error"); + return Err(ErrorInternalServerError( + "Failed to send transport, server is closed", + )); + } + + let post_path = app_data.post_path.clone(); + let ping_interval = app_data.sse_ping_interval; + let session_for_stream = session.clone(); + + // Create SSE response stream + let sse_stream = async_stream::stream! 
{ + // Send initial endpoint message + yield Ok::<_, actix_web::Error>(Bytes::from(format!( + "event: endpoint\ndata: {}?sessionId={}\n\n", + post_path, session_for_stream + ))); + + // Set up ping interval + let mut ping_interval = tokio::time::interval(ping_interval); + ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let mut rx = ReceiverStream::new(to_client_rx); + + loop { + tokio::select! { + Some(message) = rx.next() => { + match serde_json::to_string(&message) { + Ok(json) => { + yield Ok(Bytes::from(format!("event: message\ndata: {}\n\n", json))); + } + Err(e) => { + tracing::error!("Failed to serialize message: {}", e); + } + } + } + _ = ping_interval.tick() => { + yield Ok(Bytes::from(": ping\n\n")); + } + else => break, + } + } + }; + + // Clean up on disconnect + let app_data_clone = app_data.clone(); + let session_for_cleanup = session.clone(); + actix_rt::spawn(async move { + to_client_tx_clone.closed().await; + + let mut txs = app_data_clone.txs.write().await; + txs.remove(&session_for_cleanup); + tracing::debug!(%session_for_cleanup, "Closed session and cleaned up resources"); + }); + + Ok(HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header((HEADER_X_ACCEL_BUFFERING, "no")) + .streaming(sse_stream)) +} + +pub struct SseServerTransport { + stream: ReceiverStream>, + sink: PollSender>, + session_id: SessionId, + tx_store: TxStore, +} + +impl Sink> for SseServerTransport { + type Error = io::Error; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_ready_unpin(cx) + .map_err(std::io::Error::other) + } + + fn start_send( + mut self: std::pin::Pin<&mut Self>, + item: TxJsonRpcMessage, + ) -> Result<(), Self::Error> { + self.sink + .start_send_unpin(item) + .map_err(std::io::Error::other) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut 
std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink + .poll_flush_unpin(cx) + .map_err(std::io::Error::other) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let inner_close_result = self + .sink + .poll_close_unpin(cx) + .map_err(std::io::Error::other); + if inner_close_result.is_ready() { + let session_id = self.session_id.clone(); + let tx_store = self.tx_store.clone(); + tokio::spawn(async move { + tx_store.write().await.remove(&session_id); + }); + } + inner_close_result + } +} + +impl Stream for SseServerTransport { + type Item = RxJsonRpcMessage; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } +} + +#[derive(Debug)] +pub struct SseServer { + transport_rx: Arc>>, + pub config: SseServerConfig, + app_data: Data, +} + +impl SseServer { + pub async fn serve(bind: std::net::SocketAddr) -> io::Result { + Self::serve_with_config(SseServerConfig { + bind, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }) + .await + } + + pub async fn serve_with_config(mut config: SseServerConfig) -> io::Result { + let bind_addr = config.bind; + let ct = config.ct.clone(); + + // First bind to get the actual address + let listener = std::net::TcpListener::bind(bind_addr)?; + let actual_addr = listener.local_addr()?; + listener.set_nonblocking(true)?; + + // Update config with actual address + config.bind = actual_addr; + let (sse_server, _) = Self::new(config); + let app_data = sse_server.app_data.clone(); + let sse_path = sse_server.config.sse_path.clone(); + let post_path = sse_server.config.post_path.clone(); + + let server = actix_web::HttpServer::new(move || { + actix_web::App::new() + .app_data(app_data.clone()) + .wrap(middleware::NormalizePath::trim()) + .route(&sse_path, web::get().to(sse_handler)) + 
.route(&post_path, web::post().to(post_event_handler)) + }) + .listen(listener)? + .run(); + + let ct_child = ct.child_token(); + let server_handle = server.handle(); + + actix_rt::spawn(async move { + ct_child.cancelled().await; + tracing::info!("sse server cancelled"); + server_handle.stop(true).await; + }); + + actix_rt::spawn( + async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "sse server shutdown with error"); + } + } + .instrument(tracing::info_span!("sse-server", bind_address = %actual_addr)), + ); + + Ok(sse_server) + } + + pub fn new(config: SseServerConfig) -> (SseServer, Scope) { + let (app_data, transport_rx) = AppData::new( + config.post_path.clone(), + config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), + ); + + let sse_path = config.sse_path.clone(); + let post_path = config.post_path.clone(); + + let app_data = Data::new(app_data); + + let scope = web::scope("") + .app_data(app_data.clone()) + .route(&sse_path, web::get().to(sse_handler)) + .route(&post_path, web::post().to(post_event_handler)); + + let server = SseServer { + transport_rx: Arc::new(Mutex::new(transport_rx)), + config, + app_data, + }; + + (server, scope) + } + + pub fn with_service(self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + use crate::service::ServiceExt; + let ct = self.config.ct.clone(); + let transport_rx = self.transport_rx.clone(); + + actix_rt::spawn(async move { + while let Some(transport) = transport_rx.lock().await.recv().await { + let service = service_provider(); + let ct_child = ct.child_token(); + tokio::spawn(async move { + let server = service + .serve_with_ct(transport, ct_child) + .await + .map_err(std::io::Error::other)?; + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + self.config.ct.clone() + } + + /// This allows you to skip the initialization steps for incoming request. 
+ pub fn with_service_directly(self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + let ct = self.config.ct.clone(); + let transport_rx = self.transport_rx.clone(); + + actix_rt::spawn(async move { + while let Some(transport) = transport_rx.lock().await.recv().await { + let service = service_provider(); + let ct_child = ct.child_token(); + tokio::spawn(async move { + let server = serve_directly_with_ct(service, transport, None, ct_child); + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + self.config.ct.clone() + } + + pub fn cancel(&self) { + self.config.ct.cancel(); + } + + pub async fn next_transport(&self) -> Option { + self.transport_rx.lock().await.recv().await + } +} + +impl Stream for SseServer { + type Item = SseServerTransport; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let mut rx = match self.transport_rx.try_lock() { + Ok(rx) => rx, + Err(_) => { + cx.waker().wake_by_ref(); + return std::task::Poll::Pending; + } + }; + rx.poll_recv(cx) + } +} + +#[cfg(test)] +mod tests { + use futures::{SinkExt, StreamExt}; + use tokio::time::timeout; + + use super::*; + + #[tokio::test] + async fn test_session_management() { + let (app_data, transport_rx) = + AppData::new("/message".to_string(), Duration::from_secs(15)); + + // Create a session + let session_id = session_id(); + let (tx, _rx) = tokio::sync::mpsc::channel(64); + + // Insert session + app_data.txs.write().await.insert(session_id.clone(), tx); + + // Verify session exists + assert!(app_data.txs.read().await.contains_key(&session_id)); + + // Remove session + app_data.txs.write().await.remove(&session_id); + + // Verify session removed + assert!(!app_data.txs.read().await.contains_key(&session_id)); + + drop(transport_rx); + } + + #[actix_web::test] + async fn test_sse_server_creation() { + let config = SseServerConfig { + bind: 
"127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + let (sse_server, scope) = SseServer::new(config); + + assert_eq!(sse_server.config.sse_path, "/sse"); + assert_eq!(sse_server.config.post_path, "/message"); + + // Scope should be properly configured + drop(scope); // Just ensure it's created without panic + } + + #[tokio::test] + async fn test_transport_stream() { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let stream = ReceiverStream::new(rx); + let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); + let sink = PollSender::new(sink_tx); + + let mut transport = SseServerTransport { + stream, + sink, + session_id: session_id(), + tx_store: Default::default(), + }; + + // Test sending through transport + use crate::model::{EmptyResult, JsonRpcMessage, ServerResult}; + let msg: TxJsonRpcMessage = + JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); + // For PollSender, we need to send through async context + transport.send(msg).await.unwrap(); + + // Should receive the message + let received = timeout(Duration::from_millis(100), sink_rx.recv()) + .await + .unwrap() + .unwrap(); + + match received { + TxJsonRpcMessage::::Response(_) => {} + _ => panic!("Unexpected message type"), + } + + // Test receiving through transport + let client_msg: RxJsonRpcMessage = + crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + 
extensions: Default::default(), + }, + ), + }); + tx.send(client_msg).await.unwrap(); + drop(tx); + + let received = timeout(Duration::from_millis(100), transport.next()) + .await + .unwrap() + .unwrap(); + + match received { + RxJsonRpcMessage::::Notification(_) => {} + _ => panic!("Unexpected message type"), + } + } + + #[actix_web::test] + async fn test_post_event_handler_session_not_found() { + use actix_web::test; + + let (app_data, _) = AppData::new("/message".to_string(), Duration::from_secs(15)); + let app_data = Data::new(app_data); + + let query = PostEventQuery { + session_id: "non-existent".to_string(), + }; + + // Create a simple cancelled notification + let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + }, + ), + }); + + let result = post_event_handler( + app_data, + Query(query), + test::TestRequest::default().to_http_request(), + Json(client_msg), + ) + .await; + + assert!(result.is_err()); + } + + #[actix_web::test] + async fn test_server_with_cancellation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let (sse_server, _) = SseServer::new(config); + + // Test that the cancellation token is properly connected + assert!(!ct.is_cancelled()); + ct.cancel(); + assert!(ct.is_cancelled()); + + // Verify server config + assert!(sse_server.config.ct.is_cancelled()); + } + + #[actix_web::test] + async fn test_sse_stream_generation() { + let (app_data, mut transport_rx) = + 
AppData::new("/message".to_string(), Duration::from_secs(15)); + let app_data = Data::new(app_data); + + // Call SSE handler + let result = sse_handler( + app_data.clone(), + actix_web::test::TestRequest::default().to_http_request(), + ) + .await; + + assert!(result.is_ok()); + let response = result.unwrap(); + + // Check response headers + assert_eq!(response.status(), actix_web::http::StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + // Verify a transport was created + let transport = transport_rx.try_recv(); + assert!(transport.is_ok()); + } +} diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server/axum.rs similarity index 61% rename from crates/rmcp/src/transport/sse_server.rs rename to crates/rmcp/src/transport/sse_server/axum.rs index 15a65cb5..f24a9fd5 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server/axum.rs @@ -15,16 +15,15 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::{CancellationToken, PollSender}; use tracing::Instrument; +use super::common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, - transport::common::server_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; type TxStore = Arc>>>; -pub type TransportReceiver = ReceiverStream>; #[derive(Clone)] struct App { @@ -214,15 +213,6 @@ impl Stream for SseServerTransport { } } -#[derive(Debug, Clone)] -pub struct SseServerConfig { - pub bind: SocketAddr, - pub sse_path: String, - pub post_path: String, - pub ct: CancellationToken, - pub sse_keep_alive: Option, -} - #[derive(Debug)] pub struct SseServer { transport_rx: tokio::sync::mpsc::UnboundedReceiver, @@ -240,9 +230,11 @@ impl SseServer { }) .await } - pub async fn serve_with_config(config: SseServerConfig) -> 
io::Result { + pub async fn serve_with_config(mut config: SseServerConfig) -> io::Result { + let listener = tokio::net::TcpListener::bind(config.bind).await?; + // Update config with actual bound address (important when port is 0) + config.bind = listener.local_addr()?; let (sse_server, service) = Self::new(config); - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; let ct = sse_server.config.ct.child_token(); let server = axum::serve(listener, service).with_graceful_shutdown(async move { ct.cancelled().await; @@ -341,3 +333,182 @@ impl Stream for SseServer { self.transport_rx.poll_recv(cx) } } + +#[cfg(test)] +mod tests { + use futures::{SinkExt, StreamExt}; + use tokio::time::timeout; + + use super::*; + + #[tokio::test] + async fn test_session_management() { + let (app, transport_rx) = App::new("/message".to_string(), Duration::from_secs(15)); + + // Create a session + let session_id = session_id(); + let (tx, _rx) = tokio::sync::mpsc::channel(64); + + // Insert session + app.txs.write().await.insert(session_id.clone(), tx); + + // Verify session exists + assert!(app.txs.read().await.contains_key(&session_id)); + + // Remove session + app.txs.write().await.remove(&session_id); + + // Verify session removed + assert!(!app.txs.read().await.contains_key(&session_id)); + + drop(transport_rx); + } + + #[tokio::test] + async fn test_sse_server_creation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + let (sse_server, router) = SseServer::new(config); + + assert_eq!(sse_server.config.sse_path, "/sse"); + assert_eq!(sse_server.config.post_path, "/message"); + + // Router should be properly configured + drop(router); // Just ensure it's created without panic + } + + #[tokio::test] + async fn test_transport_stream() { + let (tx, rx) = 
tokio::sync::mpsc::channel(1); + let stream = ReceiverStream::new(rx); + let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(1); + let sink = PollSender::new(sink_tx); + + let mut transport = SseServerTransport { + stream, + sink, + session_id: session_id(), + tx_store: Default::default(), + }; + + // Test sending through transport + use crate::model::{EmptyResult, JsonRpcMessage, ServerResult}; + let msg: TxJsonRpcMessage = + JsonRpcMessage::Response(crate::model::JsonRpcResponse { + jsonrpc: crate::model::JsonRpcVersion2_0, + id: crate::model::NumberOrString::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }); + // For PollSender, we need to send through async context + transport.send(msg).await.unwrap(); + + // Should receive the message + let received = timeout(Duration::from_millis(100), sink_rx.recv()) + .await + .unwrap() + .unwrap(); + + match received { + TxJsonRpcMessage::::Response(_) => {} + _ => panic!("Unexpected message type"), + } + + // Test receiving through transport + let client_msg: RxJsonRpcMessage = + crate::model::JsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + }, + ), + }); + tx.send(client_msg).await.unwrap(); + drop(tx); + + let received = timeout(Duration::from_millis(100), transport.next()) + .await + .unwrap() + .unwrap(); + + match received { + RxJsonRpcMessage::::Notification(_) => {} + _ => panic!("Unexpected message type"), + } + } + + #[tokio::test] + async fn test_post_event_handler_session_not_found() { + use axum::{ + Json, + extract::{Query, State}, + http::Request, + }; + + let (app, _) = App::new("/message".to_string(), 
Duration::from_secs(15)); + + let query = PostEventQuery { + session_id: "non-existent".to_string(), + }; + + // Create a minimal request parts + let request = Request::builder() + .method("POST") + .uri("/message") + .body(()) + .unwrap(); + let (parts, _) = request.into_parts(); + + // Create a simple cancelled notification + let client_msg = ClientJsonRpcMessage::Notification(crate::model::JsonRpcNotification { + jsonrpc: crate::model::JsonRpcVersion2_0, + notification: crate::model::ClientNotification::CancelledNotification( + crate::model::Notification { + method: crate::model::CancelledNotificationMethod, + params: crate::model::CancelledNotificationParam { + request_id: crate::model::NumberOrString::Number(1), + reason: None, + }, + extensions: Default::default(), + }, + ), + }); + + let result = post_event_handler(State(app), Query(query), parts, Json(client_msg)).await; + + assert_eq!(result, Err(StatusCode::NOT_FOUND)); + } + + #[tokio::test] + async fn test_server_with_cancellation() { + let config = SseServerConfig { + bind: "127.0.0.1:0".parse().unwrap(), + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct_clone = config.ct.clone(); + let (mut sse_server, _) = SseServer::new(config); + + // Cancel immediately + ct_clone.cancel(); + + // next_transport should return None after cancellation + let transport = timeout(Duration::from_millis(100), sse_server.next_transport()).await; + assert!(transport.is_ok()); + assert!(transport.unwrap().is_none()); + } +} diff --git a/crates/rmcp/src/transport/sse_server/common.rs b/crates/rmcp/src/transport/sse_server/common.rs new file mode 100644 index 00000000..bedfe1c4 --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/common.rs @@ -0,0 +1,20 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use tokio_util::sync::CancellationToken; + +pub type SessionId = Arc; + +pub fn session_id() -> SessionId { + 
uuid::Uuid::new_v4().to_string().into() +} + +#[derive(Debug, Clone)] +pub struct SseServerConfig { + pub bind: SocketAddr, + pub sse_path: String, + pub post_path: String, + pub ct: CancellationToken, + pub sse_keep_alive: Option, +} + +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); diff --git a/crates/rmcp/src/transport/sse_server/mod.rs b/crates/rmcp/src/transport/sse_server/mod.rs new file mode 100644 index 00000000..96f90031 --- /dev/null +++ b/crates/rmcp/src/transport/sse_server/mod.rs @@ -0,0 +1,55 @@ +//! SSE Server Transport Module +//! +//! This module provides Server-Sent Events (SSE) transport implementations for MCP. +//! +//! # Module Organization +//! +//! Framework-specific implementations are organized in submodules: +//! - `axum` - Contains the Axum-based SSE server implementation +//! - `actix_web` - Contains the actix-web-based SSE server implementation +//! +//! Each submodule exports a `SseServer` type with the same interface. +//! +//! For convenience, a type alias `SseServer` is provided at the module root that resolves to: +//! - `actix_web::SseServer` when the `actix-web` feature is enabled +//! - `axum::SseServer` when only the `axum` feature is enabled +//! +//! # Examples +//! +//! Using the convenience alias (recommended for most use cases): +//! ```ignore +//! use rmcp::transport::SseServer; +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! ``` +//! +//! Using framework-specific modules (when you need a specific implementation): +//! ```ignore +//! #[cfg(feature = "axum")] +//! use rmcp::transport::sse_server::axum::SseServer; +//! #[cfg(feature = "axum")] +//! let server = SseServer::serve("127.0.0.1:8080".parse()?).await?; +//! 
``` + +#[cfg(feature = "transport-sse-server")] +pub mod common; + +// Axum implementation +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +pub mod axum; + +// Actix-web implementation +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +pub mod actix_web; + +// Convenience type alias +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +pub use actix_web::SseServer; +#[cfg(all( + feature = "transport-sse-server", + feature = "axum", + not(feature = "actix-web") +))] +pub use axum::SseServer; +// Re-export common types when transport-sse-server is enabled +#[cfg(feature = "transport-sse-server")] +pub use common::{DEFAULT_AUTO_PING_INTERVAL, SessionId, SseServerConfig, session_id}; diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index 733fc5e5..239c4d82 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -1,8 +1,80 @@ +//! Streamable HTTP Server Transport Module +//! +//! This module provides streamable HTTP transport implementations for MCP. +//! +//! # Module Organization +//! +//! Framework-specific implementations are organized in submodules: +//! - `axum` - Contains the Axum-based streamable HTTP service implementation +//! - `actix_web` - Contains the actix-web-based streamable HTTP service implementation +//! +//! Each submodule exports a `StreamableHttpService` type with the same interface. +//! +//! For convenience, a type alias `StreamableHttpService` is provided at the module root that resolves to: +//! - `actix_web::StreamableHttpService` when the `actix-web` feature is enabled +//! - `axum::StreamableHttpService` when only the `axum` feature is enabled +//! +//! # Examples +//! +//! Using the convenience alias (recommended for most use cases): +//! ```ignore +//! use rmcp::transport::StreamableHttpService; +//! 
let service = StreamableHttpService::new(|| Ok(handler), session_manager, config); +//! ``` +//! +//! Using framework-specific modules (when you need a specific implementation): +//! ```ignore +//! #[cfg(feature = "axum")] +//! use rmcp::transport::streamable_http_server::axum::StreamableHttpService; +//! #[cfg(feature = "axum")] +//! let service = StreamableHttpService::new(|| Ok(handler), session_manager, config); +//! ``` + pub mod session; -#[cfg(feature = "transport-streamable-http-server")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub mod tower; + +use std::time::Duration; + +/// Configuration for the streamable HTTP server +#[derive(Debug, Clone)] +pub struct StreamableHttpServerConfig { + /// The ping message duration for SSE connections. + pub sse_keep_alive: Option<Duration>, + /// If true, the server will create a session for each request and keep it alive. + pub stateful_mode: bool, +} + +impl Default for StreamableHttpServerConfig { + fn default() -> Self { + Self { + sse_keep_alive: Some(Duration::from_secs(15)), + stateful_mode: true, + } + } +} + +// Axum implementation +#[cfg(all(feature = "transport-streamable-http-server", feature = "axum"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "transport-streamable-http-server", feature = "axum"))) +)] +pub mod axum; + +// Actix-web implementation +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))) +)] +pub mod actix_web; + +// Export the preferred implementation as StreamableHttpService (without generic parameters) +#[cfg(all(feature = "transport-streamable-http-server", feature = "actix-web"))] +pub use actix_web::StreamableHttpService; +#[cfg(all( + feature = "transport-streamable-http-server", + feature = "axum", + not(feature = "actix-web") +))] +pub use axum::StreamableHttpService; pub use session::{SessionId, SessionManager};
-#[cfg(feature = "transport-streamable-http-server")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; diff --git a/crates/rmcp/src/transport/streamable_http_server/actix_web.rs b/crates/rmcp/src/transport/streamable_http_server/actix_web.rs new file mode 100644 index 00000000..6f7c36d0 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/actix_web.rs @@ -0,0 +1,489 @@ +use std::sync::Arc; + +use actix_web::{ + HttpRequest, HttpResponse, Result, + error::InternalError, + http::{StatusCode, header}, + middleware, + web::{self, Bytes, Data}, +}; +use futures::{Stream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; + +use super::{StreamableHttpServerConfig, session::SessionManager}; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ClientRequest}, + serve_server, + service::serve_directly, + transport::{ + OneshotTransport, TransportAdapterIdentity, + common::http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, + HEADER_X_ACCEL_BUFFERING, JSON_MIME_TYPE, + }, + }, +}; + +#[derive(Clone)] +pub struct StreamableHttpService { + pub config: StreamableHttpServerConfig, + session_manager: Arc, + service_factory: Arc Result + Send + Sync>, +} + +impl StreamableHttpService +where + S: crate::Service + Send + 'static, + M: SessionManager + 'static, +{ + pub fn new( + service_factory: impl Fn() -> Result + Send + Sync + 'static, + session_manager: Arc, + config: StreamableHttpServerConfig, + ) -> Self { + Self { + config, + session_manager, + service_factory: Arc::new(service_factory), + } + } + + fn get_service(&self) -> Result { + (self.service_factory)() + } + + /// Configure actix_web routes for the streamable HTTP server + pub fn configure(service: Arc) -> impl FnOnce(&mut web::ServiceConfig) { + move |cfg: &mut web::ServiceConfig| { + cfg.service( + web::scope("/") + .app_data(Data::new(service.clone())) + 
.wrap(middleware::NormalizePath::trim()) + .route("", web::get().to(Self::handle_get)) + .route("", web::post().to(Self::handle_post)) + .route("", web::delete().to(Self::handle_delete)), + ); + } + } + + async fn handle_get( + req: HttpRequest, + service: Data>>, + ) -> Result { + // Check accept header + let accept = req + .headers() + .get(header::ACCEPT) + .and_then(|h| h.to_str().ok()); + + if !accept.is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) { + return Ok(HttpResponse::NotAcceptable() + .body("Not Acceptable: Client must accept text/event-stream")); + } + + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + + let Some(session_id) = session_id else { + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required")); + }; + + tracing::debug!(%session_id, "GET request for SSE stream"); + + // Check if session exists + let has_session = service + .session_manager + .has_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + if !has_session { + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found")); + } + + // Check if last event id is provided + let last_event_id = req + .headers() + .get(HEADER_LAST_EVENT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned()); + + // Get the appropriate stream + let sse_stream: std::pin::Pin + Send>> = + if let Some(last_event_id) = last_event_id { + tracing::debug!(%session_id, %last_event_id, "Resuming stream from last event"); + Box::pin( + service + .session_manager + .resume(&session_id, last_event_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?, + ) + } else { + tracing::debug!(%session_id, "Creating standalone stream"); + Box::pin( + service + .session_manager + .create_standalone_stream(&session_id) + .await + .map_err(|e| InternalError::new(e, 
StatusCode::INTERNAL_SERVER_ERROR))?, + ) + }; + + // Convert to SSE format + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = sse_stream; + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! { + Some(msg) = stream.next() => { + let data = serde_json::to_string(&msg.message) + .unwrap_or_else(|_| "{}".to_string()); + let mut output = String::new(); + if let Some(id) = msg.event_id { + output.push_str(&format!("id: {}\n", id)); + } + output.push_str(&format!("data: {}\n\n", data)); + yield Ok::<_, actix_web::Error>(Bytes::from(output)); + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type("text/event-stream") + .append_header(("Cache-Control", "no-cache")) + .append_header(("X-Accel-Buffering", "no")) + .streaming(sse_stream)) + } + + async fn handle_post( + req: HttpRequest, + body: Bytes, + service: Data>>, + ) -> Result { + // Check accept header + let accept = req + .headers() + .get(header::ACCEPT) + .and_then(|h| h.to_str().ok()); + + if !accept.is_some_and(|header| { + header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE) + }) { + return Ok(HttpResponse::NotAcceptable().body( + "Not Acceptable: Client must accept both application/json and text/event-stream", + )); + } + + // Check content type + let content_type = req + .headers() + .get(header::CONTENT_TYPE) + .and_then(|h| h.to_str().ok()); + + if !content_type.is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) { + return Ok(HttpResponse::UnsupportedMediaType() + .body("Unsupported Media Type: Content-Type must be application/json")); + } + + // Deserialize the message + let mut message: ClientJsonRpcMessage = 
serde_json::from_slice(&body) + .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))?; + + tracing::debug!(?message, "POST request with message"); + + if service.config.stateful_mode { + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()); + + if let Some(session_id) = session_id { + let session_id = session_id.to_owned().into(); + tracing::debug!(%session_id, "POST request with existing session"); + + let has_session = service + .session_manager + .has_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + if !has_session { + tracing::warn!(%session_id, "Session not found"); + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session not found")); + } + + // Note: In actix-web we can't inject request parts like in tower, + // but session_id is already available through headers + + match message { + ClientJsonRpcMessage::Request(_) => { + let stream = service + .session_manager + .create_stream(&session_id, message) + .await + .map_err(|e| { + InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR) + })?; + + // Convert to SSE format + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = Box::pin(stream); + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! 
{ + Some(msg) = stream.next() => { + let data = serde_json::to_string(&msg.message) + .unwrap_or_else(|_| "{}".to_string()); + let mut output = String::new(); + if let Some(id) = msg.event_id { + output.push_str(&format!("id: {}\n", id)); + } + output.push_str(&format!("data: {}\n\n", data)); + yield Ok::<_, actix_web::Error>(Bytes::from(output)); + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .streaming(sse_stream)) + } + ClientJsonRpcMessage::Notification(_) + | ClientJsonRpcMessage::Response(_) + | ClientJsonRpcMessage::Error(_) => { + // Handle notification + service + .session_manager + .accept_message(&session_id, message) + .await + .map_err(|e| { + InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR) + })?; + + Ok(HttpResponse::Accepted().finish()) + } + ClientJsonRpcMessage::BatchRequest(_) + | ClientJsonRpcMessage::BatchResponse(_) => { + Ok(HttpResponse::NotImplemented() + .body("Batch requests are not supported yet")) + } + } + } else { + // No session id in stateful mode - create new session + tracing::debug!("POST request without session, creating new session"); + + let (session_id, transport) = service + .session_manager + .create_session() + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + tracing::info!(%session_id, "Created new session"); + + if let ClientJsonRpcMessage::Request(req) = &mut message { + if !matches!(req.request, ClientRequest::InitializeRequest(_)) { + return Ok( + HttpResponse::UnprocessableEntity().body("Expected initialize request") + ); + } + } else { + return Ok( + HttpResponse::UnprocessableEntity().body("Expected initialize 
request") + ); + } + + let service_instance = service + .get_service() + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + // Spawn a task to serve the session + tokio::spawn({ + let session_manager = service.session_manager.clone(); + let session_id = session_id.clone(); + async move { + let service = serve_server::( + service_instance, + transport, + ) + .await; + match service { + Ok(service) => { + let _ = service.waiting().await; + } + Err(e) => { + tracing::error!("Failed to create service: {e}"); + } + } + let _ = session_manager + .close_session(&session_id) + .await + .inspect_err(|e| { + tracing::error!("Failed to close session {session_id}: {e}"); + }); + } + }); + + // Get initialize response + let response = service + .session_manager + .initialize_session(&session_id, message) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + // Return SSE stream with single response + let sse_stream = async_stream::stream! { + yield Ok::<_, actix_web::Error>(Bytes::from(format!( + "data: {}\n\n", + serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string()) + ))); + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .append_header((HEADER_SESSION_ID, session_id.as_ref())) + .streaming(sse_stream)) + } + } else { + // Stateless mode + tracing::debug!("POST request in stateless mode"); + + match message { + ClientJsonRpcMessage::Request(request) => { + tracing::debug!(?request, "Processing request in stateless mode"); + + // In stateless mode, handle the request directly + let service_instance = service + .get_service() + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + let (transport, receiver) = + OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); + let service_handle = serve_directly(service_instance, transport, None); + + 
tokio::spawn(async move { + // Let the service process the request + let _ = service_handle.waiting().await; + }); + + // Convert receiver stream to SSE format + let sse_stream = ReceiverStream::new(receiver).map(|message| { + tracing::info!(?message); + let data = + serde_json::to_string(&message).unwrap_or_else(|_| "{}".to_string()); + Ok::<_, actix_web::Error>(Bytes::from(format!("data: {}\n\n", data))) + }); + + // Add keep-alive if configured + let keep_alive = service.config.sse_keep_alive; + let sse_stream = async_stream::stream! { + let mut stream = Box::pin(sse_stream); + let mut keep_alive_timer = keep_alive.map(|duration| tokio::time::interval(duration)); + + loop { + tokio::select! { + Some(result) = stream.next() => { + match result { + Ok(data) => yield Ok(data), + Err(e) => yield Err(e), + } + } + _ = async { + match keep_alive_timer.as_mut() { + Some(timer) => { + timer.tick().await; + } + None => { + std::future::pending::<()>().await; + } + } + } => { + yield Ok(Bytes::from(":ping\n\n")); + } + else => break, + } + } + }; + + Ok(HttpResponse::Ok() + .content_type(EVENT_STREAM_MIME_TYPE) + .append_header((header::CACHE_CONTROL, "no-cache")) + .append_header((HEADER_X_ACCEL_BUFFERING, "no")) + .streaming(sse_stream)) + } + _ => Ok(HttpResponse::UnprocessableEntity().body("Unexpected message type")), + } + } + } + + async fn handle_delete( + req: HttpRequest, + service: Data>>, + ) -> Result { + // Check session id + let session_id = req + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + + let Some(session_id) = session_id else { + return Ok(HttpResponse::Unauthorized().body("Unauthorized: Session ID is required")); + }; + + tracing::debug!(%session_id, "DELETE request to close session"); + + // Close session + service + .session_manager + .close_session(&session_id) + .await + .map_err(|e| InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR))?; + + tracing::info!(%session_id, "Session 
closed"); + + Ok(HttpResponse::NoContent().finish()) + } +} diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs similarity index 96% rename from crates/rmcp/src/transport/streamable_http_server/tower.rs rename to crates/rmcp/src/transport/streamable_http_server/axum.rs index 0ed0858e..87c9c083 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/axum.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; +use std::{convert::Infallible, fmt::Display, sync::Arc}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; @@ -7,7 +7,7 @@ use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; -use super::session::SessionManager; +use super::{StreamableHttpServerConfig, session::SessionManager}; use crate::{ RoleServer, model::{ClientJsonRpcMessage, ClientRequest, GetExtensions}, @@ -27,23 +27,6 @@ use crate::{ }, }; -#[derive(Debug, Clone)] -pub struct StreamableHttpServerConfig { - /// The ping message duration for SSE connections. - pub sse_keep_alive: Option, - /// If true, the server will create a session for each request and keep it alive. 
- pub stateful_mode: bool, -} - -impl Default for StreamableHttpServerConfig { - fn default() -> Self { - Self { - sse_keep_alive: Some(Duration::from_secs(15)), - stateful_mode: true, - } - } -} - pub struct StreamableHttpService<S, M> { pub config: StreamableHttpServerConfig, session_manager: Arc<M>, diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index 4f4fccee..f939654f 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -3,7 +3,7 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -51,6 +51,7 @@ impl Calculator { } } +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/crates/rmcp/tests/test_sse_server.rs b/crates/rmcp/tests/test_sse_server.rs new file mode 100644 index 00000000..1f1c79bf --- /dev/null +++ b/crates/rmcp/tests/test_sse_server.rs @@ -0,0 +1,227 @@ +#![cfg(feature = "transport-sse-server")] + +// Import framework-specific types +#[cfg(feature = "actix-web")] +use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; +use rmcp::{ServiceExt, transport::sse_server::SseServerConfig}; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +mod common; +use common::calculator::Calculator; + +async fn init() { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); +} + +// Common test logic for basic SSE server test +async fn
test_sse_server_basic_common( + bind_addr: std::net::SocketAddr, + ct: CancellationToken, + service_ct: CancellationToken, +) -> anyhow::Result<()> { + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test that server is running by making a request + let client = reqwest::Client::new(); + let response = client + .get(format!("http://{}/sse", bind_addr)) + .header("Accept", "text/event-stream") + .send() + .await?; + + // SSE endpoint should return OK and start streaming + assert_eq!(response.status(), reqwest::StatusCode::OK); + + ct.cancel(); + service_ct.cancel(); + Ok(()) +} + +#[cfg(all(feature = "transport-sse-server", feature = "axum"))] +#[tokio::test] +async fn test_axum_sse_server_basic() -> anyhow::Result<()> { + init().await; + + let config = SseServerConfig { + bind: "127.0.0.1:0".parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let sse_server = AxumSseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let service_ct = sse_server.with_service(Calculator::default); + + test_sse_server_basic_common(bind_addr, ct, service_ct).await +} + +#[cfg(all(feature = "transport-sse-server", feature = "actix-web"))] +#[actix_web::test] +async fn test_actix_sse_server_basic() -> anyhow::Result<()> { + init().await; + + let config = SseServerConfig { + bind: "127.0.0.1:0".parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct = config.ct.clone(); + let sse_server = ActixSseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let service_ct = sse_server.with_service(Calculator::default); + + test_sse_server_basic_common(bind_addr, ct, service_ct).await +} + +// Common client-server integration test logic +#[cfg(feature = 
"transport-sse-client")] +async fn test_client_server_integration_common( + actual_addr: std::net::SocketAddr, + ct: CancellationToken, +) -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + let transport = SseClientTransport::start(format!("http://{}/sse", actual_addr)).await?; + let client = ().serve(transport).await?; + + // Test basic operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + client.cancel().await?; + ct.cancel(); + Ok(()) +} + +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "axum" +))] +#[tokio::test] +async fn test_axum_client_server_integration() -> anyhow::Result<()> { + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + test_client_server_integration_common(actual_addr, ct).await +} + +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "actix-web" +))] +#[actix_web::test] +async fn test_actix_client_server_integration() -> anyhow::Result<()> { + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + test_client_server_integration_common(actual_addr, ct).await +} + +// Common concurrent clients test logic +#[cfg(feature = "transport-sse-client")] +async fn test_concurrent_clients_common( + actual_addr: std::net::SocketAddr, + ct: CancellationToken, +) -> anyhow::Result<()> { + use rmcp::transport::SseClientTransport; + + const NUM_CLIENTS: usize = 5; + let mut handles = vec![]; + + 
for i in 0..NUM_CLIENTS { + let addr = actual_addr; + let handle = tokio::spawn(async move { + let transport = SseClientTransport::start(format!("http://{}/sse", addr)).await?; + let client = ().serve(transport).await?; + + // Each client does some operations + let tools = client.list_all_tools().await?; + assert!(!tools.is_empty()); + assert_eq!(tools.len(), 2); // sum and sub + + tracing::info!("Client {} completed operations", i); + client.cancel().await?; + Ok::<(), anyhow::Error>(()) + }); + handles.push(handle); + } + + // Wait for all clients to complete + for handle in handles { + handle.await??; + } + + ct.cancel(); + Ok(()) +} + +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "axum" +))] +#[tokio::test] +async fn test_axum_concurrent_clients() -> anyhow::Result<()> { + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = AxumSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + test_concurrent_clients_common(actual_addr, ct).await +} + +#[cfg(all( + feature = "transport-sse-server", + feature = "transport-sse-client", + feature = "actix-web" +))] +#[actix_web::test] +async fn test_actix_concurrent_clients() -> anyhow::Result<()> { + init().await; + + const BIND_ADDRESS: &str = "127.0.0.1:0"; + + let sse_server = ActixSseServer::serve(BIND_ADDRESS.parse()?).await?; + let actual_addr = sse_server.config.bind; + let ct = sse_server.with_service(Calculator::default); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + test_concurrent_clients_common(actual_addr, ct).await +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 3f2761cd..6e22f9d7 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -1,12 +1,14 @@ +// Import framework-specific types 
+#[cfg(feature = "actix-web")] +use rmcp::transport::streamable_http_server::actix_web::StreamableHttpService as ActixStreamableHttpService; +#[cfg(feature = "axum")] +use rmcp::transport::streamable_http_server::axum::StreamableHttpService as AxumStreamableHttpService; use rmcp::{ ServiceExt, service::QuitReason, transport::{ ConfigureCommandExt, SseServer, StreamableHttpClientTransport, StreamableHttpServerConfig, - TokioChildProcess, - streamable_http_server::{ - session::local::LocalSessionManager, tower::StreamableHttpService, - }, + TokioChildProcess, streamable_http_server::session::local::LocalSessionManager, }, }; use tokio_util::sync::CancellationToken; @@ -17,6 +19,8 @@ use common::calculator::Calculator; const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; const STREAMABLE_HTTP_JS_BIND_ADDRESS: &str = "127.0.0.1:8002"; +#[cfg(feature = "actix-web")] +const STREAMABLE_HTTP_ACTIX_BIND_ADDRESS: &str = "127.0.0.1:8004"; #[tokio::test] async fn test_with_js_client() -> anyhow::Result<()> { @@ -78,8 +82,9 @@ async fn test_with_js_server() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "axum")] #[tokio::test] -async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { +async fn test_with_js_streamable_http_client_axum() -> anyhow::Result<()> { let _ = tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -94,8 +99,8 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .wait() .await?; - let service: StreamableHttpService = - StreamableHttpService::new( + let service: AxumStreamableHttpService = + AxumStreamableHttpService::new( || Ok(Calculator::new()), Default::default(), StreamableHttpServerConfig { @@ -125,6 +130,67 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "actix-web")] +#[actix_web::test] +async fn test_with_js_streamable_http_client_actix() -> anyhow::Result<()> { + 
let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let service = std::sync::Arc::new( + ActixStreamableHttpService::::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + }, + ), + ); + + let server = actix_web::HttpServer::new(move || { + actix_web::App::new() + .wrap(actix_web::middleware::Logger::default()) + .service( + actix_web::web::scope("/mcp") + .configure(ActixStreamableHttpService::configure(service.clone())), + ) + }) + .bind(STREAMABLE_HTTP_ACTIX_BIND_ADDRESS)? + .run(); + + let server_handle = server.handle(); + let server_task = tokio::spawn(server); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + let exit_status = tokio::process::Command::new("node") + .arg("tests/test_with_js/streamable_client.js") + .arg(format!( + "http://{}/mcp/", + STREAMABLE_HTTP_ACTIX_BIND_ADDRESS + )) + .spawn()? 
+ .wait() + .await?; + assert!(exit_status.success()); + + server_handle.stop(true).await; + let _ = server_task.await; + Ok(()) +} + #[tokio::test] async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { let _ = tracing_subscriber::registry() diff --git a/crates/rmcp/tests/test_with_js/streamable_client.js b/crates/rmcp/tests/test_with_js/streamable_client.js index 99826131..fb86b7ac 100644 --- a/crates/rmcp/tests/test_with_js/streamable_client.js +++ b/crates/rmcp/tests/test_with_js/streamable_client.js @@ -1,7 +1,8 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/mcp/`)); +const url = process.argv[2] || "http://127.0.0.1:8001/mcp/"; +const transport = new StreamableHTTPClientTransport(new URL(url)); const client = new Client( { diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 9def1722..58a84eb3 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,7 +1,11 @@ -use axum::Router; +// Import framework-specific types +#[cfg(feature = "actix-web")] +use rmcp::transport::sse_server::actix_web::SseServer as ActixSseServer; +#[cfg(feature = "axum")] +use rmcp::transport::sse_server::axum::SseServer as AxumSseServer; use rmcp::{ ServiceExt, - transport::{ConfigureCommandExt, SseServer, TokioChildProcess, sse_server::SseServerConfig}, + transport::{ConfigureCommandExt, TokioChildProcess, sse_server::SseServerConfig}, }; use tokio::time::timeout; use tokio_util::sync::CancellationToken; @@ -26,20 +30,15 @@ async fn init() -> anyhow::Result<()> { Ok(()) } -#[tokio::test] -async fn test_with_python_client() -> anyhow::Result<()> { - init().await?; - - const BIND_ADDRESS: &str = "127.0.0.1:8000"; - - let ct = SseServer::serve(BIND_ADDRESS.parse()?) - .await? 
- .with_service(Calculator::default); - +// Common test logic for Python client +async fn test_with_python_client_common( + bind_address: &str, + ct: CancellationToken, +) -> anyhow::Result<()> { let status = tokio::process::Command::new("uv") .arg("run") .arg("client.py") - .arg(format!("http://{BIND_ADDRESS}/sse")) + .arg(format!("http://{bind_address}/sse")) .current_dir("tests/test_with_python") .spawn()? .wait() @@ -49,9 +48,40 @@ async fn test_with_python_client() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "axum")] +#[tokio::test] +async fn test_with_python_client_axum() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8000"; + + let ct = AxumSseServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + test_with_python_client_common(BIND_ADDRESS, ct).await +} + +#[cfg(feature = "actix-web")] +#[actix_web::test] +async fn test_with_python_client_actix() -> anyhow::Result<()> { + init().await?; + + const BIND_ADDRESS: &str = "127.0.0.1:8000"; + + let ct = ActixSseServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + test_with_python_client_common(BIND_ADDRESS, ct).await +} + /// Test the SSE server in a nested Axum router. 
+#[cfg(feature = "axum")] #[tokio::test] async fn test_nested_with_python_client() -> anyhow::Result<()> { + use axum::Router; + init().await?; const BIND_ADDRESS: &str = "127.0.0.1:8001"; @@ -67,7 +97,7 @@ async fn test_nested_with_python_client() -> anyhow::Result<()> { let listener = tokio::net::TcpListener::bind(&sse_config.bind).await?; - let (sse_server, sse_router) = SseServer::new(sse_config); + let (sse_server, sse_router) = AxumSseServer::new(sse_config); let ct = sse_server.with_service(Calculator::default); let main_router = Router::new().nest("/nested", sse_router); diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index a6acd646..8341b70a 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -33,6 +33,8 @@ tracing-subscriber = { version = "0.3", features = [ futures = "0.3" rand = { version = "0.9", features = ["std"] } axum = { version = "0.8", features = ["macros"] } +actix-web = { version = "4", optional = true } +actix-rt = { version = "2", optional = true } schemars = { version = "0.8", optional = true } reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" @@ -82,3 +84,16 @@ path = "src/counter_hyper_streamable_http.rs" [[example]] name = "servers_sampling_stdio" path = "src/sampling_stdio.rs" + +[[example]] +name = "servers_counter_sse_actix" +path = "src/counter_sse_actix.rs" +required-features = ["actix-web"] + +[[example]] +name = "servers_counter_streamable_http_actix" +path = "src/counter_streamable_http_actix.rs" +required-features = ["actix-web"] + +[features] +actix-web = ["dep:actix-web", "dep:actix-rt", "rmcp/actix-web"] diff --git a/examples/servers/src/complex_auth_sse.rs b/examples/servers/src/complex_auth_sse.rs index 75488657..a713fda8 100644 --- a/examples/servers/src/complex_auth_sse.rs +++ b/examples/servers/src/complex_auth_sse.rs @@ -13,12 +13,11 @@ use axum::{ }; use rand::{Rng, distr::Alphanumeric}; use rmcp::transport::{ - SseServer, auth::{ AuthorizationMetadata, 
ClientRegistrationRequest, ClientRegistrationResponse, OAuthClientConfig, }, - sse_server::SseServerConfig, + sse_server::{SseServerConfig, axum::SseServer}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs index 6312180d..7c1a535a 100644 --- a/examples/servers/src/counter_hyper_streamable_http.rs +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -1,3 +1,4 @@ +// Example of using streamable HTTP server transport with hyper/tower mod common; use common::counter::Counter; use hyper_util::{ @@ -6,7 +7,7 @@ use hyper_util::{ service::TowerToHyperService, }; use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, + axum::StreamableHttpService, session::local::LocalSessionManager, }; #[tokio::main] diff --git a/examples/servers/src/counter_sse.rs b/examples/servers/src/counter_sse.rs index 373a8a6d..a62d6ec2 100644 --- a/examples/servers/src/counter_sse.rs +++ b/examples/servers/src/counter_sse.rs @@ -1,4 +1,6 @@ -use rmcp::transport::sse_server::{SseServer, SseServerConfig}; +// Example of using SSE server transport with axum framework +// This requires the "axum" feature to be enabled in Cargo.toml +use rmcp::transport::sse_server::{SseServerConfig, axum::SseServer}; use tracing_subscriber::{ layer::SubscriberExt, util::SubscriberInitExt, @@ -27,28 +29,27 @@ async fn main() -> anyhow::Result<()> { sse_keep_alive: None, }; - let (sse_server, router) = SseServer::new(config); + let ct_signal = config.ct.clone(); - // Do something with the router, e.g., add routes or middleware - - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; - - let ct = sse_server.config.ct.child_token(); + // When axum feature is enabled, use axum-specific SseServer + let sse_server = SseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + 
let ct = sse_server.with_service(Counter::new); - let server = axum::serve(listener, router).with_graceful_shutdown(async move { - ct.cancelled().await; - tracing::info!("sse server cancelled"); - }); + println!("\nšŸš€ SSE Server (axum) running at http://{}", bind_addr); + println!("šŸ“” SSE endpoint: http://{}/sse", bind_addr); + println!("šŸ“® Message endpoint: http://{}/message", bind_addr); + println!("\nPress Ctrl+C to stop the server\n"); + // Set up Ctrl-C handler tokio::spawn(async move { - if let Err(e) = server.await { - tracing::error!(error = %e, "sse server shutdown with error"); - } + tokio::signal::ctrl_c().await.ok(); + println!("\nā¹ļø Shutting down..."); + ct_signal.cancel(); }); - let ct = sse_server.with_service(Counter::new); - - tokio::signal::ctrl_c().await?; - ct.cancel(); + // Wait for cancellation + ct.cancelled().await; + println!("āœ… Server stopped"); Ok(()) } diff --git a/examples/servers/src/counter_sse_actix.rs b/examples/servers/src/counter_sse_actix.rs new file mode 100644 index 00000000..622fc49f --- /dev/null +++ b/examples/servers/src/counter_sse_actix.rs @@ -0,0 +1,61 @@ +// Example of using SSE server transport with actix-web framework +// This requires the "actix-web" feature to be enabled in Cargo.toml +use rmcp::transport::sse_server::{SseServer, SseServerConfig}; +use tracing_subscriber::{ + layer::SubscriberExt, + util::SubscriberInitExt, + {self}, +}; +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +// Note: Using #[actix_web::main] instead of #[tokio::main] +// This sets up the actix-web runtime which is required for actix-web transports +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let config = SseServerConfig { + bind: BIND_ADDRESS.parse()?, + 
sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: tokio_util::sync::CancellationToken::new(), + sse_keep_alive: None, + }; + + let ct_signal = config.ct.clone(); + + // When actix-web feature is enabled, SseServer uses actix-web implementation + // The same API works with both axum and actix-web + let sse_server = SseServer::serve_with_config(config).await?; + let bind_addr = sse_server.config.bind; + let ct = sse_server.with_service(Counter::new); + + println!( + "\nšŸš€ SSE Server (actix-web) running at http://{}", + bind_addr + ); + println!("šŸ“” SSE endpoint: http://{}/sse", bind_addr); + println!("šŸ“® Message endpoint: http://{}/message", bind_addr); + println!("\nPress Ctrl+C to stop the server\n"); + + // Set up Ctrl-C handler + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + println!("\nā¹ļø Shutting down..."); + ct_signal.cancel(); + }); + + // Wait for cancellation + ct.cancelled().await; + println!("āœ… Server stopped"); + Ok(()) +} diff --git a/examples/servers/src/counter_streamable_http_actix.rs b/examples/servers/src/counter_streamable_http_actix.rs new file mode 100644 index 00000000..fb7afa03 --- /dev/null +++ b/examples/servers/src/counter_streamable_http_actix.rs @@ -0,0 +1,48 @@ +// Example of using streamable HTTP server transport with actix-web framework +// This requires the "actix-web" feature to be enabled in Cargo.toml +mod common; +use std::sync::Arc; + +use actix_web::{App, HttpServer, middleware}; +use common::counter::Counter; +use rmcp::transport::streamable_http_server::{ + StreamableHttpService, session::local::LocalSessionManager, +}; + +// Note: Using #[actix_web::main] instead of #[tokio::main] +// This sets up the actix-web runtime which is required for actix-web transports +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + 
let bind_addr = "127.0.0.1:8080"; + + // Create the streamable HTTP service + // When actix-web feature is enabled, StreamableHttpService uses actix-web implementation + let service = Arc::new(StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + Default::default(), + )); + + println!("Starting actix-web streamable HTTP server on {}", bind_addr); + println!("POST / - Send JSON-RPC requests"); + println!("GET / - Resume SSE stream with session ID"); + println!("DELETE / - Close session"); + + // Use actix-web's HttpServer and App to host the service + // The StreamableHttpService::configure method sets up the routes + HttpServer::new(move || { + App::new() + .wrap(middleware::Logger::default()) + .configure(StreamableHttpService::configure(service.clone())) + }) + .bind(bind_addr)? + .run() + .await?; + + Ok(()) +} diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index ff00cec6..b4d8bd65 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -1,5 +1,6 @@ +// Example of using streamable HTTP server transport with axum framework use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, + axum::StreamableHttpService, session::local::LocalSessionManager, }; use tracing_subscriber::{ layer::SubscriberExt, diff --git a/examples/servers/src/simple_auth_sse.rs b/examples/servers/src/simple_auth_sse.rs index 6514c161..2539a9a3 100644 --- a/examples/servers/src/simple_auth_sse.rs +++ b/examples/servers/src/simple_auth_sse.rs @@ -16,7 +16,7 @@ use axum::{ response::{Html, Response}, routing::get, }; -use rmcp::transport::{SseServer, sse_server::SseServerConfig}; +use rmcp::transport::sse_server::{SseServerConfig, axum::SseServer}; use tokio_util::sync::CancellationToken; mod common; use common::counter::Counter;