@@ -13,7 +13,7 @@ use tokio::sync::Mutex;
13
13
use bytes:: Bytes ;
14
14
use futures_util:: { SinkExt , StreamExt } ;
15
15
use global_error:: * ;
16
- use http_body_util:: Full ;
16
+ use http_body_util:: { BodyExt , Full } ;
17
17
use hyper:: body:: Incoming as BodyIncoming ;
18
18
use hyper:: header:: HeaderName ;
19
19
use hyper:: { Request , Response , StatusCode } ;
@@ -34,6 +34,68 @@ const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
34
34
const ROUTE_CACHE_TTL : Duration = Duration :: from_secs ( 60 * 10 ) ; // 10 minutes
35
35
const PROXY_STATE_CACHE_TTL : Duration = Duration :: from_secs ( 60 * 60 ) ; // 1 hour
36
36
37
+ /// Response body type that can handle both streaming and buffered responses
38
+ #[ derive( Debug ) ]
39
+ pub enum ResponseBody {
40
+ /// Buffered response body
41
+ Full ( Full < Bytes > ) ,
42
+ /// Streaming response body
43
+ Incoming ( BodyIncoming ) ,
44
+ }
45
+
46
+ impl http_body:: Body for ResponseBody {
47
+ type Data = Bytes ;
48
+ type Error = Box < dyn std:: error:: Error + Send + Sync > ;
49
+
50
+ fn poll_frame (
51
+ self : std:: pin:: Pin < & mut Self > ,
52
+ cx : & mut std:: task:: Context < ' _ > ,
53
+ ) -> std:: task:: Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > > {
54
+ match self . get_mut ( ) {
55
+ ResponseBody :: Full ( body) => {
56
+ let pin = std:: pin:: Pin :: new ( body) ;
57
+ match pin. poll_frame ( cx) {
58
+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) ) => {
59
+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) )
60
+ }
61
+ std:: task:: Poll :: Ready ( Some ( Err ( e) ) ) => {
62
+ std:: task:: Poll :: Ready ( Some ( Err ( Box :: new ( e) ) ) )
63
+ }
64
+ std:: task:: Poll :: Ready ( None ) => std:: task:: Poll :: Ready ( None ) ,
65
+ std:: task:: Poll :: Pending => std:: task:: Poll :: Pending ,
66
+ }
67
+ }
68
+ ResponseBody :: Incoming ( body) => {
69
+ let pin = std:: pin:: Pin :: new ( body) ;
70
+ match pin. poll_frame ( cx) {
71
+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) ) => {
72
+ std:: task:: Poll :: Ready ( Some ( Ok ( frame) ) )
73
+ }
74
+ std:: task:: Poll :: Ready ( Some ( Err ( e) ) ) => {
75
+ std:: task:: Poll :: Ready ( Some ( Err ( Box :: new ( e) ) ) )
76
+ }
77
+ std:: task:: Poll :: Ready ( None ) => std:: task:: Poll :: Ready ( None ) ,
78
+ std:: task:: Poll :: Pending => std:: task:: Poll :: Pending ,
79
+ }
80
+ }
81
+ }
82
+ }
83
+
84
+ fn is_end_stream ( & self ) -> bool {
85
+ match self {
86
+ ResponseBody :: Full ( body) => body. is_end_stream ( ) ,
87
+ ResponseBody :: Incoming ( body) => body. is_end_stream ( ) ,
88
+ }
89
+ }
90
+
91
+ fn size_hint ( & self ) -> http_body:: SizeHint {
92
+ match self {
93
+ ResponseBody :: Full ( body) => body. size_hint ( ) ,
94
+ ResponseBody :: Incoming ( body) => body. size_hint ( ) ,
95
+ }
96
+ }
97
+ }
98
+
37
99
// Routing types
38
100
#[ derive( Clone , Debug ) ]
39
101
pub struct RouteTarget {
@@ -71,7 +133,7 @@ pub struct StructuredResponse {
71
133
}
72
134
73
135
impl StructuredResponse {
74
- pub fn build_response ( & self ) -> GlobalResult < Response < Full < Bytes > > > {
136
+ pub fn build_response ( & self ) -> GlobalResult < Response < ResponseBody > > {
75
137
let mut body = StdHashMap :: new ( ) ;
76
138
body. insert ( "message" , self . message . clone ( ) . into_owned ( ) ) ;
77
139
@@ -85,7 +147,7 @@ impl StructuredResponse {
85
147
let response = Response :: builder ( )
86
148
. status ( self . status )
87
149
. header ( hyper:: header:: CONTENT_TYPE , "application/json" )
88
- . body ( Full :: new ( bytes) ) ?;
150
+ . body ( ResponseBody :: Full ( Full :: new ( bytes) ) ) ?;
89
151
90
152
Ok ( response)
91
153
}
@@ -605,7 +667,7 @@ impl ProxyService {
605
667
& self ,
606
668
req : Request < BodyIncoming > ,
607
669
request_context : & mut RequestContext ,
608
- ) -> GlobalResult < Response < Full < Bytes > > > {
670
+ ) -> GlobalResult < Response < ResponseBody > > {
609
671
let host = req
610
672
. headers ( )
611
673
. get ( hyper:: header:: HOST )
@@ -641,7 +703,7 @@ impl ProxyService {
641
703
tracing:: error!( ?err, "Routing error" ) ;
642
704
return Ok ( Response :: builder ( )
643
705
. status ( StatusCode :: BAD_GATEWAY )
644
- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?) ;
706
+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ) ?) ;
645
707
}
646
708
} ;
647
709
@@ -669,14 +731,14 @@ impl ProxyService {
669
731
let res = if !self . state . check_rate_limit ( client_ip, & actor_id) . await ? {
670
732
Response :: builder ( )
671
733
. status ( StatusCode :: TOO_MANY_REQUESTS )
672
- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) )
734
+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) )
673
735
. map_err ( Into :: into)
674
736
}
675
737
// Check in-flight limit
676
738
else if !self . state . acquire_in_flight ( client_ip, & actor_id) . await ? {
677
739
Response :: builder ( )
678
740
. status ( StatusCode :: TOO_MANY_REQUESTS )
679
- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) )
741
+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) )
680
742
. map_err ( Into :: into)
681
743
} else {
682
744
// Increment metrics
@@ -782,7 +844,7 @@ impl ProxyService {
782
844
req : Request < BodyIncoming > ,
783
845
mut target : RouteTarget ,
784
846
request_context : & mut RequestContext ,
785
- ) -> GlobalResult < Response < Full < Bytes > > > {
847
+ ) -> GlobalResult < Response < ResponseBody > > {
786
848
// Get middleware config for this actor if it exists
787
849
let middleware_config = match & target. actor_id {
788
850
Some ( actor_id) => self . state . get_middleware_config ( actor_id) . await ?,
@@ -894,20 +956,38 @@ impl ProxyService {
894
956
Ok ( Ok ( resp) ) => {
895
957
let response_receive_time = request_send_start. elapsed ( ) ;
896
958
897
- // Convert the hyper::body::Incoming to http_body_util::Full<Bytes>
898
959
let ( parts, body) = resp. into_parts ( ) ;
899
960
900
- // Read the response body
901
- let body_bytes = match http_body_util:: BodyExt :: collect ( body) . await {
902
- Ok ( collected) => collected. to_bytes ( ) ,
903
- Err ( _) => Bytes :: new ( ) ,
904
- } ;
961
+ // Check if this is a streaming response by examining headers
962
+ // let is_streaming = parts.headers.get("content-type")
963
+ // .and_then(|ct| ct.to_str().ok())
964
+ // .map(|ct| ct.contains("text/event-stream") || ct.contains("application/stream"))
965
+ // .unwrap_or(false);
966
+ let is_streaming = true ;
967
+
968
+ if is_streaming {
969
+ // For streaming responses, pass through the body without buffering
970
+ tracing:: debug!( "Detected streaming response, preserving stream" ) ;
905
971
906
- // Set actual response body size in analytics
907
- request_context. guard_response_body_bytes = Some ( body_bytes . len ( ) as u64 ) ;
972
+ // We can't easily calculate response size for streaming, so set it to None
973
+ request_context. guard_response_body_bytes = None ;
908
974
909
- let full_body = Full :: new ( body_bytes) ;
910
- return Ok ( Response :: from_parts ( parts, full_body) ) ;
975
+ let streaming_body = ResponseBody :: Incoming ( body) ;
976
+ return Ok ( Response :: from_parts ( parts, streaming_body) ) ;
977
+ } else {
978
+ // For non-streaming responses, buffer as before
979
+ let body_bytes = match BodyExt :: collect ( body) . await {
980
+ Ok ( collected) => collected. to_bytes ( ) ,
981
+ Err ( _) => Bytes :: new ( ) ,
982
+ } ;
983
+
984
+ // Set actual response body size in analytics
985
+ request_context. guard_response_body_bytes =
986
+ Some ( body_bytes. len ( ) as u64 ) ;
987
+
988
+ let full_body = ResponseBody :: Full ( Full :: new ( body_bytes) ) ;
989
+ return Ok ( Response :: from_parts ( parts, full_body) ) ;
990
+ }
911
991
}
912
992
Ok ( Err ( err) ) => {
913
993
if !err. is_connect ( ) || attempts >= max_attempts {
@@ -944,7 +1024,9 @@ impl ProxyService {
944
1024
tracing:: error!( ?err, "Routing error" ) ;
945
1025
return Ok ( Response :: builder ( )
946
1026
. status ( StatusCode :: BAD_GATEWAY )
947
- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?) ;
1027
+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new (
1028
+ Bytes :: new ( ) ,
1029
+ ) ) ) ?) ;
948
1030
}
949
1031
} ;
950
1032
@@ -980,7 +1062,7 @@ impl ProxyService {
980
1062
981
1063
Ok ( Response :: builder ( )
982
1064
. status ( status_code)
983
- . body ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ?)
1065
+ . body ( ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ) ?)
984
1066
}
985
1067
986
1068
// Common function to build a request URI and headers
@@ -1033,7 +1115,7 @@ impl ProxyService {
1033
1115
req : Request < BodyIncoming > ,
1034
1116
mut target : RouteTarget ,
1035
1117
_request_context : & mut RequestContext ,
1036
- ) -> GlobalResult < Response < Full < Bytes > > > {
1118
+ ) -> GlobalResult < Response < ResponseBody > > {
1037
1119
// Get actor and server IDs for metrics and middleware
1038
1120
let actor_id = target. actor_id ;
1039
1121
let server_id = target. server_id ;
@@ -1606,15 +1688,18 @@ impl ProxyService {
1606
1688
// Create a new response with an empty body - WebSocket upgrades don't need a body
1607
1689
Ok ( Response :: from_parts (
1608
1690
parts,
1609
- Full :: < Bytes > :: new ( Bytes :: new ( ) ) ,
1691
+ ResponseBody :: Full ( Full :: < Bytes > :: new ( Bytes :: new ( ) ) ) ,
1610
1692
) )
1611
1693
}
1612
1694
}
1613
1695
1614
1696
impl ProxyService {
1615
1697
// Process an individual request
1616
1698
#[ tracing:: instrument( skip_all) ]
1617
- pub async fn process ( & self , req : Request < BodyIncoming > ) -> GlobalResult < Response < Full < Bytes > > > {
1699
+ pub async fn process (
1700
+ & self ,
1701
+ req : Request < BodyIncoming > ,
1702
+ ) -> GlobalResult < Response < ResponseBody > > {
1618
1703
// Create request context for analytics tracking
1619
1704
let mut request_context = RequestContext :: new ( self . state . clickhouse_inserter . clone ( ) ) ;
1620
1705
0 commit comments