Skip to content

Commit 15d32dc

Browse files
cors: Allow async predicate for AllowOrigin
1 parent 71a600c commit 15d32dc

File tree

3 files changed

+214
-15
lines changed

3 files changed

+214
-15
lines changed

tower-http/src/cors/allow_origin.rs

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
use std::{array, fmt, sync::Arc};
2-
31
use http::{
42
header::{self, HeaderName, HeaderValue},
53
request::Parts as RequestParts,
64
};
5+
use pin_project_lite::pin_project;
6+
use std::{
7+
array, fmt,
8+
future::Future,
9+
pin::Pin,
10+
sync::Arc,
11+
task::{Context, Poll},
12+
};
713

814
use super::{Any, WILDCARD};
915

@@ -73,6 +79,21 @@ impl AllowOrigin {
7379
Self(OriginInner::Predicate(Arc::new(f)))
7480
}
7581

82+
/// Set the allowed origins from an async predicate
83+
///
84+
/// See [`CorsLayer::allow_origin`] for more details.
85+
///
86+
/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
87+
pub fn async_predicate<F, Fut>(f: F) -> Self
88+
where
89+
F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
90+
Fut: Future<Output = bool> + Send + Sync + 'static,
91+
{
92+
Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
93+
Box::pin((f.clone())(v, p))
94+
})))
95+
}
96+
7697
/// Allow any origin, by mirroring the request origin
7798
///
7899
/// This is equivalent to
@@ -90,18 +111,70 @@ impl AllowOrigin {
90111
matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
91112
}
92113

93-
pub(super) fn to_header(
114+
pub(super) fn to_future(
94115
&self,
95116
origin: Option<&HeaderValue>,
96117
parts: &RequestParts,
97-
) -> Option<(HeaderName, HeaderValue)> {
98-
let allow_origin = match &self.0 {
99-
OriginInner::Const(v) => v.clone(),
100-
OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(),
101-
OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(),
102-
};
118+
) -> AllowOriginFuture {
119+
let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
103120

104-
Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin))
121+
match &self.0 {
122+
OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
123+
OriginInner::List(l) => {
124+
AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
125+
}
126+
OriginInner::Predicate(c) => AllowOriginFuture::ok(
127+
origin
128+
.filter(|origin| c(origin, parts))
129+
.map(|o| (name, o.clone())),
130+
),
131+
OriginInner::AsyncPredicate(f) => {
132+
if let Some(origin) = origin.cloned() {
133+
let fut = f(origin.clone(), parts);
134+
AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
135+
} else {
136+
AllowOriginFuture::ok(None)
137+
}
138+
}
139+
}
140+
}
141+
}
142+
143+
pin_project! {
144+
#[project = AllowOriginFutureProj]
145+
pub(super) enum AllowOriginFuture {
146+
Ok{
147+
res: Option<(HeaderName, HeaderValue)>
148+
},
149+
Future{
150+
#[pin]
151+
future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
152+
},
153+
}
154+
}
155+
156+
impl AllowOriginFuture {
157+
fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
158+
Self::Ok { res }
159+
}
160+
161+
fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
162+
future: F,
163+
) -> Self {
164+
Self::Future {
165+
future: Box::pin(future),
166+
}
167+
}
168+
}
169+
170+
impl Future for AllowOriginFuture {
171+
type Output = Option<(HeaderName, HeaderValue)>;
172+
173+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174+
match self.project() {
175+
AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
176+
AllowOriginFutureProj::Future { future } => future.poll(cx),
177+
}
105178
}
106179
}
107180

@@ -111,6 +184,7 @@ impl fmt::Debug for AllowOrigin {
111184
OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
112185
OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
113186
OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
187+
OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
114188
}
115189
}
116190
}
@@ -147,6 +221,17 @@ enum OriginInner {
147221
Predicate(
148222
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
149223
),
224+
AsyncPredicate(
225+
Arc<
226+
dyn for<'a> Fn(
227+
HeaderValue,
228+
&'a RequestParts,
229+
) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
230+
+ Send
231+
+ Sync
232+
+ 'static,
233+
>,
234+
),
150235
}
151236

152237
impl Default for OriginInner {

tower-http/src/cors/mod.rs

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
5050
#![allow(clippy::enum_variant_names)]
5151

52+
use allow_origin::AllowOriginFuture;
5253
use bytes::{BufMut, BytesMut};
5354
use http::{
5455
header::{self, HeaderName},
@@ -326,6 +327,52 @@ impl CorsLayer {
326327
/// ));
327328
/// ```
328329
///
330+
/// You can also use an async closure:
331+
///
332+
/// ```
333+
/// # #[derive(Clone)]
334+
/// # struct Client;
335+
/// # fn get_api_client() -> Client {
336+
/// # Client
337+
/// # }
338+
/// # impl Client {
339+
/// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
340+
/// # vec![HeaderValue::from_static("http://example.com")]
341+
/// # }
342+
/// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
343+
/// # vec![HeaderValue::from_static("http://example.com")]
344+
/// # }
345+
/// # }
346+
/// use tower_http::cors::{CorsLayer, AllowOrigin};
347+
/// use http::{request::Parts as RequestParts, HeaderValue};
348+
///
349+
/// let client = get_api_client();
350+
///
351+
/// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
352+
/// |origin: HeaderValue, _request_parts: &RequestParts| async move {
353+
/// // fetch list of origins that are allowed
354+
/// let origins = client.fetch_allowed_origins().await;
355+
/// origins.contains(&origin)
356+
/// },
357+
/// ));
358+
///
359+
/// let client = get_api_client();
360+
///
361+
/// // if using &RequestParts, make sure all the values are owned
362+
/// // before passing into the future
363+
/// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
364+
/// |origin: HeaderValue, parts: &RequestParts| {
365+
/// let path = parts.uri.path().to_owned();
366+
///
367+
/// async move {
368+
/// // fetch list of origins that are allowed for this path
369+
/// let origins = client.fetch_allowed_origins_for_path(path).await;
370+
/// origins.contains(&origin)
371+
/// }
372+
/// },
373+
/// ));
374+
/// ```
375+
///
329376
/// Note that multiple calls to this method will override any previous
330377
/// calls.
331378
///
@@ -621,11 +668,13 @@ where
621668

622669
// These headers are applied to both preflight and subsequent regular CORS requests:
623670
// https://fetch.spec.whatwg.org/#http-responses
624-
headers.extend(self.layer.allow_origin.to_header(origin, &parts));
671+
625672
headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
626673
headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
627674
headers.extend(self.layer.vary.to_header());
628675

676+
let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);
677+
629678
// Return results immediately upon preflight request
630679
if parts.method == Method::OPTIONS {
631680
// These headers are applied only to preflight requests
@@ -634,7 +683,10 @@ where
634683
headers.extend(self.layer.max_age.to_header(origin, &parts));
635684

636685
ResponseFuture {
637-
inner: Kind::PreflightCall { headers },
686+
inner: Kind::PreflightCall {
687+
allow_origin_future,
688+
headers,
689+
},
638690
}
639691
} else {
640692
// This header is applied only to non-preflight requests
@@ -643,6 +695,8 @@ where
643695
let req = Request::from_parts(parts, body);
644696
ResponseFuture {
645697
inner: Kind::CorsCall {
698+
allow_origin_future,
699+
allow_origin_complete: false,
646700
future: self.inner.call(req),
647701
headers,
648702
},
@@ -663,11 +717,16 @@ pin_project! {
663717
#[project = KindProj]
664718
enum Kind<F> {
665719
CorsCall {
720+
#[pin]
721+
allow_origin_future: AllowOriginFuture,
722+
allow_origin_complete: bool,
666723
#[pin]
667724
future: F,
668725
headers: HeaderMap,
669726
},
670727
PreflightCall {
728+
#[pin]
729+
allow_origin_future: AllowOriginFuture,
671730
headers: HeaderMap,
672731
},
673732
}
@@ -682,7 +741,17 @@ where
682741

683742
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
684743
match self.project().inner.project() {
685-
KindProj::CorsCall { future, headers } => {
744+
KindProj::CorsCall {
745+
allow_origin_future,
746+
allow_origin_complete,
747+
future,
748+
headers,
749+
} => {
750+
if !*allow_origin_complete {
751+
headers.extend(ready!(allow_origin_future.poll(cx)));
752+
*allow_origin_complete = true;
753+
}
754+
686755
let mut response: Response<B> = ready!(future.poll(cx))?;
687756

688757
let response_headers = response.headers_mut();
@@ -697,7 +766,12 @@ where
697766

698767
Poll::Ready(Ok(response))
699768
}
700-
KindProj::PreflightCall { headers } => {
769+
KindProj::PreflightCall {
770+
allow_origin_future,
771+
headers,
772+
} => {
773+
headers.extend(ready!(allow_origin_future.poll(cx)));
774+
701775
let mut response = Response::new(B::default());
702776
mem::swap(response.headers_mut(), headers);
703777

tower-http/src/cors/tests.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::test_helpers::Body;
44
use http::{header, HeaderValue, Request, Response};
55
use tower::{service_fn, util::ServiceExt, Layer};
66

7-
use crate::cors::CorsLayer;
7+
use crate::cors::{AllowOrigin, CorsLayer};
88

99
#[tokio::test]
1010
#[allow(
@@ -31,3 +31,43 @@ async fn vary_set_by_inner_service() {
3131
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
3232
assert_eq!(vary_headers.next(), None);
3333
}
34+
35+
#[tokio::test]
36+
async fn test_allow_origin_async_predicate() {
37+
#[derive(Clone)]
38+
struct Client;
39+
40+
impl Client {
41+
async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
42+
vec![HeaderValue::from_static("http://example.com")]
43+
}
44+
}
45+
46+
let client = Client;
47+
48+
let allow_origin = AllowOrigin::async_predicate(|origin, parts| {
49+
let path = parts.uri.path().to_owned();
50+
51+
async move {
52+
let origins = client.fetch_allowed_origins_for_path(path).await;
53+
54+
origins.contains(&origin)
55+
}
56+
});
57+
58+
let valid_origin = HeaderValue::from_static("http://example.com");
59+
let parts = http::Request::new("hello world").into_parts().0;
60+
61+
let header = allow_origin
62+
.to_future(Some(&valid_origin), &parts)
63+
.await
64+
.unwrap();
65+
assert_eq!(header.0, header::ACCESS_CONTROL_ALLOW_ORIGIN);
66+
assert_eq!(header.1, valid_origin);
67+
68+
let invalid_origin = HeaderValue::from_static("http://example.org");
69+
let parts = http::Request::new("hello world").into_parts().0;
70+
71+
let res = allow_origin.to_future(Some(&invalid_origin), &parts).await;
72+
assert!(res.is_none());
73+
}

0 commit comments

Comments
 (0)