1
- use std:: { array, fmt, sync:: Arc } ;
2
-
3
1
use http:: {
4
2
header:: { self , HeaderName , HeaderValue } ,
5
3
request:: Parts as RequestParts ,
6
4
} ;
5
+ use pin_project_lite:: pin_project;
6
+ use std:: {
7
+ array, fmt,
8
+ future:: Future ,
9
+ pin:: Pin ,
10
+ sync:: Arc ,
11
+ task:: { Context , Poll } ,
12
+ } ;
7
13
8
14
use super :: { Any , WILDCARD } ;
9
15
@@ -73,6 +79,21 @@ impl AllowOrigin {
73
79
Self ( OriginInner :: Predicate ( Arc :: new ( f) ) )
74
80
}
75
81
82
+ /// Set the allowed origins from an async predicate
83
+ ///
84
+ /// See [`CorsLayer::allow_origin`] for more details.
85
+ ///
86
+ /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
87
+ pub fn async_predicate < F , Fut > ( f : F ) -> Self
88
+ where
89
+ F : FnOnce ( HeaderValue , & RequestParts ) -> Fut + Send + Sync + ' static + Clone ,
90
+ Fut : Future < Output = bool > + Send + Sync + ' static ,
91
+ {
92
+ Self ( OriginInner :: AsyncPredicate ( Arc :: new ( move |v, p| {
93
+ Box :: pin ( ( f. clone ( ) ) ( v, p) )
94
+ } ) ) )
95
+ }
96
+
76
97
/// Allow any origin, by mirroring the request origin
77
98
///
78
99
/// This is equivalent to
@@ -90,18 +111,70 @@ impl AllowOrigin {
90
111
matches ! ( & self . 0 , OriginInner :: Const ( v) if v == WILDCARD )
91
112
}
92
113
93
- pub ( super ) fn to_header (
114
+ pub ( super ) fn to_future (
94
115
& self ,
95
116
origin : Option < & HeaderValue > ,
96
117
parts : & RequestParts ,
97
- ) -> Option < ( HeaderName , HeaderValue ) > {
98
- let allow_origin = match & self . 0 {
99
- OriginInner :: Const ( v) => v. clone ( ) ,
100
- OriginInner :: List ( l) => origin. filter ( |o| l. contains ( o) ) ?. clone ( ) ,
101
- OriginInner :: Predicate ( c) => origin. filter ( |origin| c ( origin, parts) ) ?. clone ( ) ,
102
- } ;
118
+ ) -> AllowOriginFuture {
119
+ let name = header:: ACCESS_CONTROL_ALLOW_ORIGIN ;
103
120
104
- Some ( ( header:: ACCESS_CONTROL_ALLOW_ORIGIN , allow_origin) )
121
+ match & self . 0 {
122
+ OriginInner :: Const ( v) => AllowOriginFuture :: ok ( Some ( ( name, v. clone ( ) ) ) ) ,
123
+ OriginInner :: List ( l) => {
124
+ AllowOriginFuture :: ok ( origin. filter ( |o| l. contains ( o) ) . map ( |o| ( name, o. clone ( ) ) ) )
125
+ }
126
+ OriginInner :: Predicate ( c) => AllowOriginFuture :: ok (
127
+ origin
128
+ . filter ( |origin| c ( origin, parts) )
129
+ . map ( |o| ( name, o. clone ( ) ) ) ,
130
+ ) ,
131
+ OriginInner :: AsyncPredicate ( f) => {
132
+ if let Some ( origin) = origin. cloned ( ) {
133
+ let fut = f ( origin. clone ( ) , parts) ;
134
+ AllowOriginFuture :: fut ( async move { fut. await . then_some ( ( name, origin) ) } )
135
+ } else {
136
+ AllowOriginFuture :: ok ( None )
137
+ }
138
+ }
139
+ }
140
+ }
141
+ }
142
+
143
+ pin_project ! {
144
+ #[ project = AllowOriginFutureProj ]
145
+ pub ( super ) enum AllowOriginFuture {
146
+ Ok {
147
+ res: Option <( HeaderName , HeaderValue ) >
148
+ } ,
149
+ Future {
150
+ #[ pin]
151
+ future: Pin <Box <dyn Future <Output = Option <( HeaderName , HeaderValue ) >> + Send + ' static >>
152
+ } ,
153
+ }
154
+ }
155
+
156
+ impl AllowOriginFuture {
157
+ fn ok ( res : Option < ( HeaderName , HeaderValue ) > ) -> Self {
158
+ Self :: Ok { res }
159
+ }
160
+
161
+ fn fut < F : Future < Output = Option < ( HeaderName , HeaderValue ) > > + Send + ' static > (
162
+ future : F ,
163
+ ) -> Self {
164
+ Self :: Future {
165
+ future : Box :: pin ( future) ,
166
+ }
167
+ }
168
+ }
169
+
170
+ impl Future for AllowOriginFuture {
171
+ type Output = Option < ( HeaderName , HeaderValue ) > ;
172
+
173
+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
174
+ match self . project ( ) {
175
+ AllowOriginFutureProj :: Ok { res } => Poll :: Ready ( res. take ( ) ) ,
176
+ AllowOriginFutureProj :: Future { future } => future. poll ( cx) ,
177
+ }
105
178
}
106
179
}
107
180
@@ -111,6 +184,7 @@ impl fmt::Debug for AllowOrigin {
111
184
OriginInner :: Const ( inner) => f. debug_tuple ( "Const" ) . field ( inner) . finish ( ) ,
112
185
OriginInner :: List ( inner) => f. debug_tuple ( "List" ) . field ( inner) . finish ( ) ,
113
186
OriginInner :: Predicate ( _) => f. debug_tuple ( "Predicate" ) . finish ( ) ,
187
+ OriginInner :: AsyncPredicate ( _) => f. debug_tuple ( "AsyncPredicate" ) . finish ( ) ,
114
188
}
115
189
}
116
190
}
@@ -147,6 +221,17 @@ enum OriginInner {
147
221
Predicate (
148
222
Arc < dyn for < ' a > Fn ( & ' a HeaderValue , & ' a RequestParts ) -> bool + Send + Sync + ' static > ,
149
223
) ,
224
+ AsyncPredicate (
225
+ Arc <
226
+ dyn for < ' a > Fn (
227
+ HeaderValue ,
228
+ & ' a RequestParts ,
229
+ ) -> Pin < Box < dyn Future < Output = bool > + Send + ' static > >
230
+ + Send
231
+ + Sync
232
+ + ' static ,
233
+ > ,
234
+ ) ,
150
235
}
151
236
152
237
impl Default for OriginInner {
0 commit comments