Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions packages/edge/infra/guard/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license.workspace = true
global-error.workspace = true
bytes = "1.6.0"
futures = "0.3.30"
http-body = "1.0.0"
http-body-util = "0.1.1"
hyper = { version = "1.6.0", features = ["full", "http1", "http2"] }
hyper-util = { version = "0.1.10", features = ["full"] }
Expand Down Expand Up @@ -39,3 +40,4 @@ clickhouse-inserter.workspace = true
futures-util = "0.3.30"
futures = "0.3.30"
reqwest = { version = "0.11.27", features = ["native-tls"] }
tokio-stream = "0.1.15"
131 changes: 108 additions & 23 deletions packages/edge/infra/guard/core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::sync::Mutex;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use global_error::*;
use http_body_util::Full;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming as BodyIncoming;
use hyper::header::HeaderName;
use hyper::{Request, Response, StatusCode};
Expand All @@ -34,6 +34,68 @@ const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour

/// Response body type that can handle both streaming and buffered responses
#[derive(Debug)]
pub enum ResponseBody {
/// Buffered response body
Full(Full<Bytes>),
/// Streaming response body
Incoming(BodyIncoming),
}

impl http_body::Body for ResponseBody {
type Data = Bytes;
type Error = Box<dyn std::error::Error + Send + Sync>;

fn poll_frame(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
match self.get_mut() {
ResponseBody::Full(body) => {
let pin = std::pin::Pin::new(body);
match pin.poll_frame(cx) {
std::task::Poll::Ready(Some(Ok(frame))) => {
std::task::Poll::Ready(Some(Ok(frame)))
}
std::task::Poll::Ready(Some(Err(e))) => {
std::task::Poll::Ready(Some(Err(Box::new(e))))
}
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
ResponseBody::Incoming(body) => {
let pin = std::pin::Pin::new(body);
match pin.poll_frame(cx) {
std::task::Poll::Ready(Some(Ok(frame))) => {
std::task::Poll::Ready(Some(Ok(frame)))
}
std::task::Poll::Ready(Some(Err(e))) => {
std::task::Poll::Ready(Some(Err(Box::new(e))))
}
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
}

fn is_end_stream(&self) -> bool {
match self {
ResponseBody::Full(body) => body.is_end_stream(),
ResponseBody::Incoming(body) => body.is_end_stream(),
}
}

fn size_hint(&self) -> http_body::SizeHint {
match self {
ResponseBody::Full(body) => body.size_hint(),
ResponseBody::Incoming(body) => body.size_hint(),
}
}
}

// Routing types
#[derive(Clone, Debug)]
pub struct RouteTarget {
Expand Down Expand Up @@ -71,7 +133,7 @@ pub struct StructuredResponse {
}

impl StructuredResponse {
pub fn build_response(&self) -> GlobalResult<Response<Full<Bytes>>> {
pub fn build_response(&self) -> GlobalResult<Response<ResponseBody>> {
let mut body = StdHashMap::new();
body.insert("message", self.message.clone().into_owned());

Expand All @@ -85,7 +147,7 @@ impl StructuredResponse {
let response = Response::builder()
.status(self.status)
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(Full::new(bytes))?;
.body(ResponseBody::Full(Full::new(bytes)))?;

Ok(response)
}
Expand Down Expand Up @@ -605,7 +667,7 @@ impl ProxyService {
&self,
req: Request<BodyIncoming>,
request_context: &mut RequestContext,
) -> GlobalResult<Response<Full<Bytes>>> {
) -> GlobalResult<Response<ResponseBody>> {
let host = req
.headers()
.get(hyper::header::HOST)
Expand Down Expand Up @@ -641,7 +703,7 @@ impl ProxyService {
tracing::error!(?err, "Routing error");
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::<Bytes>::new(Bytes::new()))?);
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))?);
}
};

Expand Down Expand Up @@ -669,14 +731,14 @@ impl ProxyService {
let res = if !self.state.check_rate_limit(client_ip, &actor_id).await? {
Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Full::<Bytes>::new(Bytes::new()))
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))
.map_err(Into::into)
}
// Check in-flight limit
else if !self.state.acquire_in_flight(client_ip, &actor_id).await? {
Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Full::<Bytes>::new(Bytes::new()))
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))
.map_err(Into::into)
} else {
// Increment metrics
Expand Down Expand Up @@ -782,7 +844,7 @@ impl ProxyService {
req: Request<BodyIncoming>,
mut target: RouteTarget,
request_context: &mut RequestContext,
) -> GlobalResult<Response<Full<Bytes>>> {
) -> GlobalResult<Response<ResponseBody>> {
// Get middleware config for this actor if it exists
let middleware_config = match &target.actor_id {
Some(actor_id) => self.state.get_middleware_config(actor_id).await?,
Expand Down Expand Up @@ -894,20 +956,38 @@ impl ProxyService {
Ok(Ok(resp)) => {
let response_receive_time = request_send_start.elapsed();

// Convert the hyper::body::Incoming to http_body_util::Full<Bytes>
let (parts, body) = resp.into_parts();

// Read the response body
let body_bytes = match http_body_util::BodyExt::collect(body).await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
// Check if this is a streaming response by examining headers
// let is_streaming = parts.headers.get("content-type")
// .and_then(|ct| ct.to_str().ok())
// .map(|ct| ct.contains("text/event-stream") || ct.contains("application/stream"))
// .unwrap_or(false);
let is_streaming = true;

if is_streaming {
// For streaming responses, pass through the body without buffering
tracing::debug!("Detected streaming response, preserving stream");

// Set actual response body size in analytics
request_context.guard_response_body_bytes = Some(body_bytes.len() as u64);
// We can't easily calculate response size for streaming, so set it to None
request_context.guard_response_body_bytes = None;

let full_body = Full::new(body_bytes);
return Ok(Response::from_parts(parts, full_body));
let streaming_body = ResponseBody::Incoming(body);
return Ok(Response::from_parts(parts, streaming_body));
} else {
// For non-streaming responses, buffer as before
let body_bytes = match BodyExt::collect(body).await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};

// Set actual response body size in analytics
request_context.guard_response_body_bytes =
Some(body_bytes.len() as u64);

let full_body = ResponseBody::Full(Full::new(body_bytes));
return Ok(Response::from_parts(parts, full_body));
}
}
Ok(Err(err)) => {
if !err.is_connect() || attempts >= max_attempts {
Expand Down Expand Up @@ -944,7 +1024,9 @@ impl ProxyService {
tracing::error!(?err, "Routing error");
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::<Bytes>::new(Bytes::new()))?);
.body(ResponseBody::Full(Full::<Bytes>::new(
Bytes::new(),
)))?);
}
};

Expand Down Expand Up @@ -980,7 +1062,7 @@ impl ProxyService {

Ok(Response::builder()
.status(status_code)
.body(Full::<Bytes>::new(Bytes::new()))?)
.body(ResponseBody::Full(Full::<Bytes>::new(Bytes::new())))?)
}

// Common function to build a request URI and headers
Expand Down Expand Up @@ -1033,7 +1115,7 @@ impl ProxyService {
req: Request<BodyIncoming>,
mut target: RouteTarget,
_request_context: &mut RequestContext,
) -> GlobalResult<Response<Full<Bytes>>> {
) -> GlobalResult<Response<ResponseBody>> {
// Get actor and server IDs for metrics and middleware
let actor_id = target.actor_id;
let server_id = target.server_id;
Expand Down Expand Up @@ -1606,15 +1688,18 @@ impl ProxyService {
// Create a new response with an empty body - WebSocket upgrades don't need a body
Ok(Response::from_parts(
parts,
Full::<Bytes>::new(Bytes::new()),
ResponseBody::Full(Full::<Bytes>::new(Bytes::new())),
))
}
}

impl ProxyService {
// Process an individual request
#[tracing::instrument(skip_all)]
pub async fn process(&self, req: Request<BodyIncoming>) -> GlobalResult<Response<Full<Bytes>>> {
pub async fn process(
&self,
req: Request<BodyIncoming>,
) -> GlobalResult<Response<ResponseBody>> {
// Create request context for analytics tracking
let mut request_context = RequestContext::new(self.state.clickhouse_inserter.clone());

Expand Down
5 changes: 3 additions & 2 deletions packages/edge/infra/guard/core/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use hyper_util::rt::TokioIo;
use rivet_guard_core::{
proxy_service::{
MaxInFlightConfig, MiddlewareConfig, MiddlewareFn, MiddlewareResponse, RateLimitConfig,
RetryConfig, RouteTarget, RoutingFn, RoutingResponse, RoutingResult, RoutingTimeout,
RetryConfig, RouteConfig, RouteTarget, RoutingFn, RoutingOutput, RoutingTimeout,
TimeoutConfig,
},
GlobalErrorWrapper,
Expand Down Expand Up @@ -445,7 +445,7 @@ pub fn create_test_routing_fn(test_server: &TestServer) -> RoutingFn {
path: path.to_string(),
};

Ok(RoutingResponse::Ok(RouteConfig {
Ok(RoutingOutput::Route(RouteConfig {
targets: vec![target],
timeout: RoutingTimeout {
routing_timeout: 5, // 5 seconds for routing timeout
Expand Down Expand Up @@ -552,6 +552,7 @@ pub async fn start_guard_with_middleware(
routing_fn_clone,
middleware_fn_clone,
rivet_guard_core::proxy_service::PortType::Http, // Default port type for tests
None, // No ClickHouse inserter for tests
));

// Run the server until shutdown signal
Expand Down
Loading
Loading