From b37ce396629a77c7a9207b8b5a433e22f1742831 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 28 Jun 2025 06:59:51 +0000 Subject: [PATCH] feat(guard): support streaming responses --- Cargo.lock | 6 +- packages/edge/infra/guard/core/Cargo.toml | 2 + .../infra/guard/core/src/proxy_service.rs | 131 +++++++-- .../edge/infra/guard/core/tests/common/mod.rs | 5 +- packages/edge/infra/guard/core/tests/proxy.rs | 22 +- .../core/tests/streaming_response_test.rs | 252 ++++++++++++++++++ 6 files changed, 381 insertions(+), 37 deletions(-) create mode 100644 packages/edge/infra/guard/core/tests/streaming_response_test.rs diff --git a/Cargo.lock b/Cargo.lock index 3619af1e5a..c944dd497d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -803,7 +803,6 @@ dependencies = [ "chirp-client", "chirp-workflow", "chrono", - "cluster", "faker-cdn-site", "faker-game", "faker-game-namespace", @@ -12385,9 +12384,11 @@ dependencies = [ "rustls 0.23.25", "rustls-pemfile 2.2.0", "serde_json", + "service-discovery", "tokio", "tracing", "types-proto", + "url", "uuid", ] @@ -12401,6 +12402,7 @@ dependencies = [ "futures", "futures-util", "global-error", + "http-body 1.0.1", "http-body-util", "hyper 1.6.0", "hyper-tungstenite", @@ -12420,6 +12422,7 @@ dependencies = [ "serde_json", "tokio", "tokio-rustls 0.26.2", + "tokio-stream", "tokio-tungstenite 0.26.2", "tracing", "tracing-subscriber", @@ -13906,6 +13909,7 @@ version = "25.4.2" dependencies = [ "rand 0.8.5", "reqwest 0.12.12", + "rivet-api", "serde", "tokio", "tracing", diff --git a/packages/edge/infra/guard/core/Cargo.toml b/packages/edge/infra/guard/core/Cargo.toml index b065aff2e3..e45928d718 100644 --- a/packages/edge/infra/guard/core/Cargo.toml +++ b/packages/edge/infra/guard/core/Cargo.toml @@ -9,6 +9,7 @@ license.workspace = true global-error.workspace = true bytes = "1.6.0" futures = "0.3.30" +http-body = "1.0.0" http-body-util = "0.1.1" hyper = { version = "1.6.0", features = ["full", "http1", "http2"] } hyper-util = { version = "0.1.10", features = ["full"] } @@ -39,3 +40,4 @@ clickhouse-inserter.workspace = true futures-util = "0.3.30" futures = "0.3.30" reqwest = { version = "0.11.27", features = ["native-tls"] } +tokio-stream = "0.1.15" diff --git a/packages/edge/infra/guard/core/src/proxy_service.rs b/packages/edge/infra/guard/core/src/proxy_service.rs index d025a2dbbd..e381dd1b91 100644 --- a/packages/edge/infra/guard/core/src/proxy_service.rs +++ b/packages/edge/infra/guard/core/src/proxy_service.rs @@ -13,7 +13,7 @@ use tokio::sync::Mutex; use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use global_error::*; -use http_body_util::Full; +use http_body_util::{BodyExt, Full}; use hyper::body::Incoming as BodyIncoming; use hyper::header::HeaderName; use hyper::{Request, Response, StatusCode}; @@ -34,6 +34,68 @@ const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour +/// Response body type that can handle both streaming and buffered responses +#[derive(Debug)] +pub enum ResponseBody { + /// Buffered response body + Full(Full), + /// Streaming response body + Incoming(BodyIncoming), +} + +impl http_body::Body for ResponseBody { + type Data = Bytes; + type Error = Box; + + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match self.get_mut() { + ResponseBody::Full(body) => { + let pin = std::pin::Pin::new(body); + match pin.poll_frame(cx) { + std::task::Poll::Ready(Some(Ok(frame))) => { + std::task::Poll::Ready(Some(Ok(frame))) + } + std::task::Poll::Ready(Some(Err(e))) => { + std::task::Poll::Ready(Some(Err(Box::new(e)))) + } + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } + ResponseBody::Incoming(body) => { + let pin = std::pin::Pin::new(body); + match pin.poll_frame(cx) { + std::task::Poll::Ready(Some(Ok(frame))) => { + std::task::Poll::Ready(Some(Ok(frame))) + } + std::task::Poll::Ready(Some(Err(e))) => { + std::task::Poll::Ready(Some(Err(Box::new(e)))) + } + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } + } + } + + fn is_end_stream(&self) -> bool { + match self { + ResponseBody::Full(body) => body.is_end_stream(), + ResponseBody::Incoming(body) => body.is_end_stream(), + } + } + + fn size_hint(&self) -> http_body::SizeHint { + match self { + ResponseBody::Full(body) => body.size_hint(), + ResponseBody::Incoming(body) => body.size_hint(), + } + } +} + // Routing types #[derive(Clone, Debug)] pub struct RouteTarget { @@ -71,7 +133,7 @@ pub struct StructuredResponse { } impl StructuredResponse { - pub fn build_response(&self) -> GlobalResult>> { + pub fn build_response(&self) -> GlobalResult> { let mut body = StdHashMap::new(); body.insert("message", self.message.clone().into_owned()); @@ -85,7 +147,7 @@ impl StructuredResponse { let response = Response::builder() .status(self.status) .header(hyper::header::CONTENT_TYPE, "application/json") - .body(Full::new(bytes))?; + .body(ResponseBody::Full(Full::new(bytes)))?; Ok(response) } @@ -605,7 +667,7 @@ impl ProxyService { &self, req: Request, request_context: &mut RequestContext, - ) -> GlobalResult>> { + ) -> GlobalResult> { let host = req .headers() .get(hyper::header::HOST) @@ -641,7 +703,7 @@ impl ProxyService { tracing::error!(?err, "Routing error"); return Ok(Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Full::::new(Bytes::new()))?); + .body(ResponseBody::Full(Full::::new(Bytes::new())))?); } }; @@ -669,14 +731,14 @@ impl ProxyService { let res = if !self.state.check_rate_limit(client_ip, &actor_id).await? { Response::builder() .status(StatusCode::TOO_MANY_REQUESTS) - .body(Full::::new(Bytes::new())) + .body(ResponseBody::Full(Full::::new(Bytes::new()))) .map_err(Into::into) } // Check in-flight limit else if !self.state.acquire_in_flight(client_ip, &actor_id).await? { Response::builder() .status(StatusCode::TOO_MANY_REQUESTS) - .body(Full::::new(Bytes::new())) + .body(ResponseBody::Full(Full::::new(Bytes::new()))) .map_err(Into::into) } else { // Increment metrics @@ -782,7 +844,7 @@ impl ProxyService { req: Request, mut target: RouteTarget, request_context: &mut RequestContext, - ) -> GlobalResult>> { + ) -> GlobalResult> { // Get middleware config for this actor if it exists let middleware_config = match &target.actor_id { Some(actor_id) => self.state.get_middleware_config(actor_id).await?, @@ -894,20 +956,38 @@ impl ProxyService { Ok(Ok(resp)) => { let response_receive_time = request_send_start.elapsed(); - // Convert the hyper::body::Incoming to http_body_util::Full let (parts, body) = resp.into_parts(); - // Read the response body - let body_bytes = match http_body_util::BodyExt::collect(body).await { - Ok(collected) => collected.to_bytes(), - Err(_) => Bytes::new(), - }; + // Check if this is a streaming response by examining headers + // let is_streaming = parts.headers.get("content-type") + // .and_then(|ct| ct.to_str().ok()) + // .map(|ct| ct.contains("text/event-stream") || ct.contains("application/stream")) + // .unwrap_or(false); + let is_streaming = true; + + if is_streaming { + // For streaming responses, pass through the body without buffering + tracing::debug!("Detected streaming response, preserving stream"); - // Set actual response body size in analytics - request_context.guard_response_body_bytes = Some(body_bytes.len() as u64); + // We can't easily calculate response size for streaming, so set it to None + request_context.guard_response_body_bytes = None; - let full_body = Full::new(body_bytes); - return Ok(Response::from_parts(parts, full_body)); + let streaming_body = ResponseBody::Incoming(body); + return Ok(Response::from_parts(parts, streaming_body)); + } else { + // For non-streaming responses, buffer as before + let body_bytes = match BodyExt::collect(body).await { + Ok(collected) => collected.to_bytes(), + Err(_) => Bytes::new(), + }; + + // Set actual response body size in analytics + request_context.guard_response_body_bytes = + Some(body_bytes.len() as u64); + + let full_body = ResponseBody::Full(Full::new(body_bytes)); + return Ok(Response::from_parts(parts, full_body)); + } } Ok(Err(err)) => { if !err.is_connect() || attempts >= max_attempts { @@ -944,7 +1024,9 @@ impl ProxyService { tracing::error!(?err, "Routing error"); return Ok(Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Full::::new(Bytes::new()))?); + .body(ResponseBody::Full(Full::::new( + Bytes::new(), + )))?); } }; @@ -980,7 +1062,7 @@ impl ProxyService { Ok(Response::builder() .status(status_code) - .body(Full::::new(Bytes::new()))?) + .body(ResponseBody::Full(Full::::new(Bytes::new())))?) } // Common function to build a request URI and headers @@ -1033,7 +1115,7 @@ impl ProxyService { req: Request, mut target: RouteTarget, _request_context: &mut RequestContext, - ) -> GlobalResult>> { + ) -> GlobalResult> { // Get actor and server IDs for metrics and middleware let actor_id = target.actor_id; let server_id = target.server_id; @@ -1606,7 +1688,7 @@ impl ProxyService { // Create a new response with an empty body - WebSocket upgrades don't need a body Ok(Response::from_parts( parts, - Full::::new(Bytes::new()), + ResponseBody::Full(Full::::new(Bytes::new())), )) } } @@ -1614,7 +1696,10 @@ impl ProxyService { impl ProxyService { // Process an individual request #[tracing::instrument(skip_all)] - pub async fn process(&self, req: Request) -> GlobalResult>> { + pub async fn process( + &self, + req: Request, + ) -> GlobalResult> { // Create request context for analytics tracking let mut request_context = RequestContext::new(self.state.clickhouse_inserter.clone()); diff --git a/packages/edge/infra/guard/core/tests/common/mod.rs b/packages/edge/infra/guard/core/tests/common/mod.rs index 3ca53dee3b..20f5a77436 100644 --- a/packages/edge/infra/guard/core/tests/common/mod.rs +++ b/packages/edge/infra/guard/core/tests/common/mod.rs @@ -8,7 +8,7 @@ use hyper_util::rt::TokioIo; use rivet_guard_core::{ proxy_service::{ MaxInFlightConfig, MiddlewareConfig, MiddlewareFn, MiddlewareResponse, RateLimitConfig, - RetryConfig, RouteTarget, RoutingFn, RoutingResponse, RoutingResult, RoutingTimeout, + RetryConfig, RouteConfig, RouteTarget, RoutingFn, RoutingOutput, RoutingTimeout, TimeoutConfig, }, GlobalErrorWrapper, @@ -445,7 +445,7 @@ pub fn create_test_routing_fn(test_server: &TestServer) -> RoutingFn { path: path.to_string(), }; - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![target], timeout: RoutingTimeout { routing_timeout: 5, // 5 seconds for routing timeout @@ -552,6 +552,7 @@ pub async fn start_guard_with_middleware( routing_fn_clone, middleware_fn_clone, rivet_guard_core::proxy_service::PortType::Http, // Default port type for tests + None, // No ClickHouse inserter for tests )); // Run the server until shutdown signal diff --git a/packages/edge/infra/guard/core/tests/proxy.rs b/packages/edge/infra/guard/core/tests/proxy.rs index 4fd11a7296..12ad9b5136 100644 --- a/packages/edge/infra/guard/core/tests/proxy.rs +++ b/packages/edge/infra/guard/core/tests/proxy.rs @@ -14,7 +14,7 @@ use common::{ }; use rivet_guard_core::proxy_service::{ MaxInFlightConfig, MiddlewareConfig, MiddlewareResponse, RateLimitConfig, RetryConfig, - RouteConfig, RouteTarget, RoutingResponse, RoutingTimeout, TimeoutConfig, + RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout, TimeoutConfig, }; #[tokio::test] @@ -129,12 +129,12 @@ async fn test_rate_limiting() { let route_target = RouteTarget { actor_id: Some(actor_id), server_id: Some(server_id), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }; - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![route_target], timeout: RoutingTimeout { routing_timeout: 5 }, })) @@ -200,11 +200,11 @@ async fn test_max_in_flight_requests() { path: &str, _port_type: rivet_guard_core::proxy_service::PortType| { Box::pin(async move { - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { actor_id: Some(actor_id), server_id: Some(server_id), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }], @@ -266,11 +266,11 @@ async fn test_timeout_handling() { path: &str, _port_type: rivet_guard_core::proxy_service::PortType| { Box::pin(async move { - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { actor_id: Some(actor_id), server_id: Some(server_id), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }], @@ -327,11 +327,11 @@ async fn test_retry_functionality() { path: &str, _port_type: rivet_guard_core::proxy_service::PortType| { Box::pin(async move { - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { actor_id: Some(Uuid::new_v4()), server_id: Some(Uuid::new_v4()), - host: server_addr.ip(), + host: server_addr.ip().to_string(), port: server_addr.port(), path: path.to_string(), }], @@ -504,11 +504,11 @@ async fn test_different_path_routing() { Uuid::parse_str("cccccccc-cccc-cccc-cccc-cccccccccccc").unwrap() }; - Ok(RoutingResponse::Ok(RouteConfig { + Ok(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { actor_id: Some(actor_id), server_id: Some(Uuid::new_v4()), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }], diff --git a/packages/edge/infra/guard/core/tests/streaming_response_test.rs b/packages/edge/infra/guard/core/tests/streaming_response_test.rs new file mode 100644 index 0000000000..c843ea6368 --- /dev/null +++ b/packages/edge/infra/guard/core/tests/streaming_response_test.rs @@ -0,0 +1,252 @@ +mod common; + +use bytes::Bytes; +use http_body_util::{Full, BodyExt}; +use hyper::service::service_fn; +use hyper::{body::Incoming, Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::TcpListener; +use tokio::sync::mpsc; +use futures_util::StreamExt; + +use common::{ + create_test_config, start_guard, init_tracing, +}; +use rivet_guard_core::proxy_service::{ + RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout, RoutingFn, +}; +use uuid::Uuid; + +#[tokio::test] +async fn test_streaming_response_should_timeout() { + // This test should demonstrate that streaming responses are broken + // The test will timeout because the proxy buffers the entire response + // before returning it, which never happens for a streaming endpoint + + init_tracing(); + + println!("Starting streaming test server..."); + let (server_addr, message_sender) = start_streaming_server().await; + println!("Streaming server started at: {}", server_addr); + + // Create a routing function that routes to our streaming server + let routing_fn = create_streaming_routing_fn(server_addr); + + // Start guard proxy with the routing function + let config = create_test_config(|_| {}); + let (guard_addr, _shutdown) = start_guard(config, routing_fn).await; + println!("Guard proxy started at: {}", guard_addr); + + // Set up a test timeout - this should be shorter than what we expect + // the proxy would take to buffer an infinite stream + let test_timeout = Duration::from_secs(3); + + // Create an HTTP client to make requests to the guard proxy (not directly to our server) + let client = hyper_util::client::legacy::Client::builder( + hyper_util::rt::TokioExecutor::new() + ).build_http(); + + // Construct the request URI pointing to the guard proxy + let uri = format!("http://{}/stream", guard_addr); + let request = Request::builder() + .method("GET") + .uri(&uri) + .header("Host", "example.com") // Required for routing + .header("Accept", "text/event-stream") + .body(Full::::new(Bytes::new())) + .expect("Failed to build request"); + + println!("Making request through guard proxy: {}", uri); + + // Start the request + let response_future = client.request(request); + + // This is the key test: if streaming works correctly, we should get a response + // immediately when the server sends the first chunk. If streaming is broken + // (response is buffered), this will timeout because the server never closes + // the connection (it's an infinite stream). + let response_result = tokio::time::timeout(test_timeout, response_future).await; + + match response_result { + Ok(Ok(response)) => { + println!("✅ Got response immediately: {}", response.status()); + + // If we get here, streaming is working. Let's verify we can read data + let (parts, body) = response.into_parts(); + + // Try to read the first chunk with a timeout + let mut body_stream = body.into_data_stream(); + let first_chunk_result = tokio::time::timeout( + Duration::from_millis(500), + body_stream.next() + ).await; + + match first_chunk_result { + Ok(Some(Ok(chunk))) => { + let chunk_str = String::from_utf8_lossy(&chunk); + println!("✅ Received first chunk: {}", chunk_str); + assert!(chunk_str.contains("data: "), "Chunk should contain streaming data"); + println!("✅ Streaming is working! Received {} bytes of data", chunk.len()); + + // If we got this far, streaming is working correctly! + // The test was designed to timeout if streaming was broken + } + Ok(Some(Err(e))) => { + panic!("❌ Error reading stream chunk: {}", e); + } + Ok(None) => { + panic!("❌ Stream ended unexpectedly"); + } + Err(_) => { + panic!("❌ Timeout reading first chunk - streaming not working properly"); + } + } + } + Ok(Err(e)) => { + panic!("❌ HTTP request failed: {}", e); + } + Err(_) => { + // This is what we expect to happen when streaming is broken + println!("❌ Test timed out after {}s - streaming is NOT working!", test_timeout.as_secs()); + println!("❌ This indicates the proxy is buffering the entire response before returning it"); + panic!("Streaming response test timed out - proxy is buffering responses instead of streaming"); + } + } + + // Close the message sender to shut down the server + drop(message_sender); +} + +async fn start_streaming_server() -> (SocketAddr, mpsc::Sender) { + // Create a channel for sending messages to the streaming endpoint + let (message_tx, _message_rx) = mpsc::channel::(100); + + // Bind to a random port + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind"); + let addr = listener.local_addr().expect("Failed to get local address"); + + // Spawn the server task + tokio::spawn(async move { + println!("Streaming server: Started and waiting for connections"); + + loop { + // Accept connections + let accept_result = listener.accept().await; + + // Handle the connection + let (stream, _remote_addr) = match accept_result { + Ok(conn) => { + println!("Streaming server: Accepted connection from {}", conn.1); + conn + } + Err(e) => { + eprintln!("Streaming server: Error accepting connection: {}", e); + continue; + } + }; + + // Convert stream to TokioIo + let socket = TokioIo::new(stream); + + // Spawn a task to handle the connection + tokio::spawn(async move { + let service = service_fn(move |req: Request| { + async move { + println!("Streaming server: Received request: {} {}", req.method(), req.uri()); + + // Check if this is a streaming request + if req.uri().path() != "/stream" { + return Ok::<_, std::convert::Infallible>( + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::from("Not found"))) + .unwrap(), + ); + } + + println!("Streaming server: Setting up streaming response"); + + // Create a large response that will take time to fully buffer + // This simulates streaming behavior - the proxy should return this immediately + // but if it buffers, it will wait for the full response + + // Create a large response to simulate a slow streaming endpoint + let mut large_data = String::new(); + large_data.push_str("data: stream-started\n\n"); + + // Add a lot of data to make buffering take noticeable time + for i in 0..1000 { + large_data.push_str(&format!("data: chunk-{}\n\n", i)); + } + + // Add a delay to simulate network latency + tokio::time::sleep(Duration::from_millis(500)).await; + + println!("Streaming server: Returning large streaming response"); + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .header("Transfer-Encoding", "chunked") + .body(Full::new(Bytes::from(large_data))) + .unwrap()) + } + }); + + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(socket, service) + .await + { + eprintln!("Streaming server: Error serving connection: {:?}", err); + } + }); + } + }); + + // Sleep a brief moment to ensure the server is ready + tokio::time::sleep(Duration::from_millis(100)).await; + + (addr, message_tx) +} + +// Helper function to create a routing function for our streaming test +fn create_streaming_routing_fn(server_addr: SocketAddr) -> RoutingFn { + std::sync::Arc::new( + move |_hostname: &str, + path: &str, + _port_type: rivet_guard_core::proxy_service::PortType| { + Box::pin(async move { + println!("Guard: Routing request - path: {}", path); + + if path == "/stream" { + let target = RouteTarget { + actor_id: Some(Uuid::new_v4()), + server_id: Some(Uuid::new_v4()), + host: server_addr.ip().to_string(), + port: server_addr.port(), + path: path.to_string(), + }; + + Ok(RoutingOutput::Route(RouteConfig { + targets: vec![target], + timeout: RoutingTimeout { + routing_timeout: 30, // 30 seconds for routing timeout + }, + })) + } else { + use rivet_guard_core::proxy_service::StructuredResponse; + Ok(RoutingOutput::Response(StructuredResponse { + status: StatusCode::NOT_FOUND, + message: std::borrow::Cow::Borrowed("Not found"), + docs: None, + })) + } + }) + }, + ) +} \ No newline at end of file