Skip to content

Commit 90ab6ac

Browse files
committed
feat(guard): support streaming responses (#2667)
<!-- Please make sure there is an issue that this PR is correlated to. --> ## Changes <!-- If there are frontend changes, please include screenshots. -->
1 parent f872ff4 commit 90ab6ac

File tree

6 files changed

+381
-37
lines changed

6 files changed

+381
-37
lines changed

Cargo.lock

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/edge/infra/guard/core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ license.workspace = true
99
global-error.workspace = true
1010
bytes = "1.6.0"
1111
futures = "0.3.30"
12+
http-body = "1.0.0"
1213
http-body-util = "0.1.1"
1314
hyper = { version = "1.6.0", features = ["full", "http1", "http2"] }
1415
hyper-util = { version = "0.1.10", features = ["full"] }
@@ -39,3 +40,4 @@ clickhouse-inserter.workspace = true
3940
futures-util = "0.3.30"
4041
futures = "0.3.30"
4142
reqwest = { version = "0.11.27", features = ["native-tls"] }
43+
tokio-stream = "0.1.15"

packages/edge/infra/guard/core/src/proxy_service.rs

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use tokio::sync::Mutex;
1313
use bytes::Bytes;
1414
use futures_util::{SinkExt, StreamExt};
1515
use global_error::*;
16-
use http_body_util::Full;
16+
use http_body_util::{BodyExt, Full};
1717
use hyper::body::Incoming as BodyIncoming;
1818
use hyper::header::HeaderName;
1919
use hyper::{Request, Response, StatusCode};
@@ -34,6 +34,68 @@ const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
3434
const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
3535
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
3636

37+
/// Response body type that can handle both streaming and buffered responses
38+
#[derive(Debug)]
39+
pub enum ResponseBody {
40+
/// Buffered response body
41+
Full(Full<Bytes>),
42+
/// Streaming response body
43+
Incoming(BodyIncoming),
44+
}
45+
46+
impl http_body::Body for ResponseBody {
47+
type Data = Bytes;
48+
type Error = Box<dyn std::error::Error + Send + Sync>;
49+
50+
fn poll_frame(
51+
self: std::pin::Pin<&mut Self>,
52+
cx: &mut std::task::Context<'_>,
53+
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
54+
match self.get_mut() {
55+
ResponseBody::Full(body) => {
56+
let pin = std::pin::Pin::new(body);
57+
match pin.poll_frame(cx) {
58+
std::task::Poll::Ready(Some(Ok(frame))) => {
59+
std::task::Poll::Ready(Some(Ok(frame)))
60+
}
61+
std::task::Poll::Ready(Some(Err(e))) => {
62+
std::task::Poll::Ready(Some(Err(Box::new(e))))
63+
}
64+
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
65+
std::task::Poll::Pending => std::task::Poll::Pending,
66+
}
67+
}
68+
ResponseBody::Incoming(body) => {
69+
let pin = std::pin::Pin::new(body);
70+
match pin.poll_frame(cx) {
71+
std::task::Poll::Ready(Some(Ok(frame))) => {
72+
std::task::Poll::Ready(Some(Ok(frame)))
73+
}
74+
std::task::Poll::Ready(Some(Err(e))) => {
75+
std::task::Poll::Ready(Some(Err(Box::new(e))))
76+
}
77+
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
78+
std::task::Poll::Pending => std::task::Poll::Pending,
79+
}
80+
}
81+
}
82+
}
83+
84+
fn is_end_stream(&self) -> bool {
85+
match self {
86+
ResponseBody::Full(body) => body.is_end_stream(),
87+
ResponseBody::Incoming(body) => body.is_end_stream(),
88+
}
89+
}
90+
91+
fn size_hint(&self) -> http_body::SizeHint {
92+
match self {
93+
ResponseBody::Full(body) => body.size_hint(),
94+
ResponseBody::Incoming(body) => body.size_hint(),
95+
}
96+
}
97+
}
98+
3799
// Routing types
38100
#[derive(Clone, Debug)]
39101
pub struct RouteTarget {
@@ -71,7 +133,7 @@ pub struct StructuredResponse {
71133
}
72134

73135
impl StructuredResponse {
74-
pub fn build_response(&self) -> GlobalResult<Response<Full<Bytes>>> {
136+
pub fn build_response(&self) -> GlobalResult<Response<ResponseBody>> {
75137
let mut body = StdHashMap::new();
76138
body.insert("message", self.message.clone().into_owned());
77139

@@ -85,7 +147,7 @@ impl StructuredResponse {
85147
let response = Response::builder()
86148
.status(self.status)
87149
.header(hyper::header::CONTENT_TYPE, "application/json")
88-
.body(Full::new(bytes))?;
150+
.body(ResponseBody::Full(Full::new(bytes)))?;
89151

90152
Ok(response)
91153
}
@@ -605,7 +667,7 @@ impl ProxyService {
605667
&self,
606668
req: Request<BodyIncoming>,
607669
request_context: &mut RequestContext,
608-
) -> GlobalResult<Response<Full<Bytes>>> {
670+
) -> GlobalResult<Response<ResponseBody>> {
609671
let host = req
610672
.headers()
611673
.get(hyper::header::HOST)
@@ -641,7 +703,7 @@ impl ProxyService {
641703
tracing::error!(?err, "Routing error");
642704
return Ok(Response::builder()
643705
.status(StatusCode::BAD_GATEWAY)
644-
.body(Full::<Bytes>::new(Bytes::new()))?);
706+
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))?);
645707
}
646708
};
647709

@@ -669,14 +731,14 @@ impl ProxyService {
669731
let res = if !self.state.check_rate_limit(client_ip, &actor_id).await? {
670732
Response::builder()
671733
.status(StatusCode::TOO_MANY_REQUESTS)
672-
.body(Full::<Bytes>::new(Bytes::new()))
734+
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))
673735
.map_err(Into::into)
674736
}
675737
// Check in-flight limit
676738
else if !self.state.acquire_in_flight(client_ip, &actor_id).await? {
677739
Response::builder()
678740
.status(StatusCode::TOO_MANY_REQUESTS)
679-
.body(Full::<Bytes>::new(Bytes::new()))
741+
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))
680742
.map_err(Into::into)
681743
} else {
682744
// Increment metrics
@@ -782,7 +844,7 @@ impl ProxyService {
782844
req: Request<BodyIncoming>,
783845
mut target: RouteTarget,
784846
request_context: &mut RequestContext,
785-
) -> GlobalResult<Response<Full<Bytes>>> {
847+
) -> GlobalResult<Response<ResponseBody>> {
786848
// Get middleware config for this actor if it exists
787849
let middleware_config = match &target.actor_id {
788850
Some(actor_id) => self.state.get_middleware_config(actor_id).await?,
@@ -894,20 +956,38 @@ impl ProxyService {
894956
Ok(Ok(resp)) => {
895957
let response_receive_time = request_send_start.elapsed();
896958

897-
// Convert the hyper::body::Incoming to http_body_util::Full<Bytes>
898959
let (parts, body) = resp.into_parts();
899960

900-
// Read the response body
901-
let body_bytes = match http_body_util::BodyExt::collect(body).await {
902-
Ok(collected) => collected.to_bytes(),
903-
Err(_) => Bytes::new(),
904-
};
961+
// Check if this is a streaming response by examining headers
962+
// let is_streaming = parts.headers.get("content-type")
963+
// .and_then(|ct| ct.to_str().ok())
964+
// .map(|ct| ct.contains("text/event-stream") || ct.contains("application/stream"))
965+
// .unwrap_or(false);
966+
let is_streaming = true;
967+
968+
if is_streaming {
969+
// For streaming responses, pass through the body without buffering
970+
tracing::debug!("Detected streaming response, preserving stream");
905971

906-
// Set actual response body size in analytics
907-
request_context.guard_response_body_bytes = Some(body_bytes.len() as u64);
972+
// We can't easily calculate response size for streaming, so set it to None
973+
request_context.guard_response_body_bytes = None;
908974

909-
let full_body = Full::new(body_bytes);
910-
return Ok(Response::from_parts(parts, full_body));
975+
let streaming_body = ResponseBody::Incoming(body);
976+
return Ok(Response::from_parts(parts, streaming_body));
977+
} else {
978+
// For non-streaming responses, buffer as before
979+
let body_bytes = match BodyExt::collect(body).await {
980+
Ok(collected) => collected.to_bytes(),
981+
Err(_) => Bytes::new(),
982+
};
983+
984+
// Set actual response body size in analytics
985+
request_context.guard_response_body_bytes =
986+
Some(body_bytes.len() as u64);
987+
988+
let full_body = ResponseBody::Full(Full::new(body_bytes));
989+
return Ok(Response::from_parts(parts, full_body));
990+
}
911991
}
912992
Ok(Err(err)) => {
913993
if !err.is_connect() || attempts >= max_attempts {
@@ -944,7 +1024,9 @@ impl ProxyService {
9441024
tracing::error!(?err, "Routing error");
9451025
return Ok(Response::builder()
9461026
.status(StatusCode::BAD_GATEWAY)
947-
.body(Full::<Bytes>::new(Bytes::new()))?);
1027+
.body(ResponseBody::Full(Full::<Bytes>::new(
1028+
Bytes::new(),
1029+
)))?);
9481030
}
9491031
};
9501032

@@ -980,7 +1062,7 @@ impl ProxyService {
9801062

9811063
Ok(Response::builder()
9821064
.status(status_code)
983-
.body(Full::<Bytes>::new(Bytes::new()))?)
1065+
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))?)
9841066
}
9851067

9861068
// Common function to build a request URI and headers
@@ -1033,7 +1115,7 @@ impl ProxyService {
10331115
req: Request<BodyIncoming>,
10341116
mut target: RouteTarget,
10351117
_request_context: &mut RequestContext,
1036-
) -> GlobalResult<Response<Full<Bytes>>> {
1118+
) -> GlobalResult<Response<ResponseBody>> {
10371119
// Get actor and server IDs for metrics and middleware
10381120
let actor_id = target.actor_id;
10391121
let server_id = target.server_id;
@@ -1606,15 +1688,18 @@ impl ProxyService {
16061688
// Create a new response with an empty body - WebSocket upgrades don't need a body
16071689
Ok(Response::from_parts(
16081690
parts,
1609-
Full::<Bytes>::new(Bytes::new()),
1691+
ResponseBody::Full(Full::<Bytes>::new(Bytes::new())),
16101692
))
16111693
}
16121694
}
16131695

16141696
impl ProxyService {
16151697
// Process an individual request
16161698
#[tracing::instrument(skip_all)]
1617-
pub async fn process(&self, req: Request<BodyIncoming>) -> GlobalResult<Response<Full<Bytes>>> {
1699+
pub async fn process(
1700+
&self,
1701+
req: Request<BodyIncoming>,
1702+
) -> GlobalResult<Response<ResponseBody>> {
16181703
// Create request context for analytics tracking
16191704
let mut request_context = RequestContext::new(self.state.clickhouse_inserter.clone());
16201705

packages/edge/infra/guard/core/tests/common/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use hyper_util::rt::TokioIo;
88
use rivet_guard_core::{
99
proxy_service::{
1010
MaxInFlightConfig, MiddlewareConfig, MiddlewareFn, MiddlewareResponse, RateLimitConfig,
11-
RetryConfig, RouteTarget, RoutingFn, RoutingResponse, RoutingResult, RoutingTimeout,
11+
RetryConfig, RouteConfig, RouteTarget, RoutingFn, RoutingOutput, RoutingTimeout,
1212
TimeoutConfig,
1313
},
1414
GlobalErrorWrapper,
@@ -445,7 +445,7 @@ pub fn create_test_routing_fn(test_server: &TestServer) -> RoutingFn {
445445
path: path.to_string(),
446446
};
447447

448-
Ok(RoutingResponse::Ok(RouteConfig {
448+
Ok(RoutingOutput::Route(RouteConfig {
449449
targets: vec![target],
450450
timeout: RoutingTimeout {
451451
routing_timeout: 5, // 5 seconds for routing timeout
@@ -552,6 +552,7 @@ pub async fn start_guard_with_middleware(
552552
routing_fn_clone,
553553
middleware_fn_clone,
554554
rivet_guard_core::proxy_service::PortType::Http, // Default port type for tests
555+
None, // No ClickHouse inserter for tests
555556
));
556557

557558
// Run the server until shutdown signal

0 commit comments

Comments
 (0)