Skip to content

Commit 52ab2c4

Browse files
committed
add clienthooks
1 parent 8497fdb commit 52ab2c4

File tree

4 files changed

+193
-42
lines changed

4 files changed

+193
-42
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ version = "0.12.18"
1515
default-features = false
1616
features = ["stream"]
1717

18+
1819
[features]
1920
default = ["default-tls", "dep:hmac", "dep:sha2"]
2021
default-tls = ["reqwest/default-tls"]

src/s3/client.rs

Lines changed: 142 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
//! S3 client to perform bucket and object operations
1717
18+
use std::fmt::Debug;
1819
use std::fs::File;
1920
use std::io::prelude::*;
2021
use std::mem;
@@ -24,20 +25,22 @@ use std::sync::{Arc, OnceLock};
2425
use crate::s3::builders::{BucketExists, ComposeSource};
2526
use crate::s3::creds::Provider;
2627
use crate::s3::error::{Error, ErrorCode, ErrorResponse};
27-
use crate::s3::http::BaseUrl;
28+
use crate::s3::http::{BaseUrl, Url};
2829
use crate::s3::multimap::{Multimap, MultimapExt};
2930
use crate::s3::response::a_response_traits::{HasEtagFromHeaders, HasS3Fields};
3031
use crate::s3::response::*;
3132
use crate::s3::segmented_bytes::SegmentedBytes;
3233
use crate::s3::signer::sign_v4_s3;
3334
use crate::s3::utils::{EMPTY_SHA256, sha256_hash_sb, to_amz_date, utc_now};
35+
pub use crate::s3::client::hooks::ClientHooks;
3436

3537
use bytes::Bytes;
3638
use dashmap::DashMap;
3739
use http::HeaderMap;
38-
use hyper::http::Method;
40+
pub use hyper::http::Method;
3941
use rand::Rng;
4042
use reqwest::Body;
43+
pub use reqwest::{Error as ReqwestError, Response};
4144

4245
mod append_object;
4346
mod bucket_exists;
@@ -69,6 +72,7 @@ mod get_object_tagging;
6972
mod get_presigned_object_url;
7073
mod get_presigned_post_form_data;
7174
mod get_region;
75+
pub mod hooks;
7276
mod list_buckets;
7377
mod list_objects;
7478
mod listen_bucket_notification;
@@ -123,6 +127,7 @@ pub const MAX_MULTIPART_COUNT: u16 = 10_000;
123127
pub struct ClientBuilder {
124128
base_url: BaseUrl,
125129
provider: Option<Arc<dyn Provider + Send + Sync + 'static>>,
130+
client_hooks: Vec<Arc<dyn ClientHooks + Send + Sync + 'static>>,
126131
ssl_cert_file: Option<PathBuf>,
127132
ignore_cert_check: Option<bool>,
128133
app_info: Option<(String, String)>,
@@ -138,6 +143,13 @@ impl ClientBuilder {
138143
}
139144
}
140145

146+
/// Add a client hook to the builder. Hooks will be called after each other in
147+
/// order they were added.
148+
pub fn hook(mut self, hooks: Arc<dyn ClientHooks + Send + Sync + 'static>) -> Self {
149+
self.client_hooks.push(hooks);
150+
self
151+
}
152+
141153
/// Set the credential provider. If not, set anonymous access is used.
142154
pub fn provider<P: Provider + Send + Sync + 'static>(mut self, provider: Option<P>) -> Self {
143155
self.provider = provider.map(|p| Arc::new(p) as Arc<dyn Provider + Send + Sync + 'static>);
@@ -209,6 +221,7 @@ impl ClientBuilder {
209221
shared: Arc::new(SharedClientItems {
210222
base_url: self.base_url,
211223
provider: self.provider,
224+
client_hooks: self.client_hooks,
212225
..Default::default()
213226
}),
214227
})
@@ -427,55 +440,69 @@ impl Client {
427440
body: Option<Arc<SegmentedBytes>>,
428441
retry: bool,
429442
) -> Result<reqwest::Response, Error> {
430-
let url = self.shared.base_url.build_url(
443+
let mut url = self.shared.base_url.build_url(
431444
method,
432445
region,
433446
query_params,
434447
bucket_name,
435448
object_name,
436449
)?;
450+
let mut extensions = http::Extensions::default();
451+
headers.add("Host", url.host_header_value());
437452

438-
{
439-
headers.add("Host", url.host_header_value());
440-
let sha256: String = match *method {
441-
Method::PUT | Method::POST => {
442-
if !headers.contains_key("Content-Type") {
443-
headers.add("Content-Type", "application/octet-stream");
444-
}
445-
let len: usize = body.as_ref().map_or(0, |b| b.len());
446-
headers.add("Content-Length", len.to_string());
447-
match body {
448-
None => EMPTY_SHA256.into(),
449-
Some(ref v) => {
450-
let clone = v.clone();
451-
async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await
452-
}
453-
}
453+
let sha256: String = match *method {
454+
Method::PUT | Method::POST => {
455+
if !headers.contains_key("Content-Type") {
456+
headers.add("Content-Type", "application/octet-stream");
454457
}
455-
_ => EMPTY_SHA256.into(),
456-
};
457-
headers.add("x-amz-content-sha256", sha256.clone());
458-
459-
let date = utc_now();
460-
headers.add("x-amz-date", to_amz_date(date));
461-
if let Some(p) = &self.shared.provider {
462-
let creds = p.fetch();
463-
if creds.session_token.is_some() {
464-
headers.add("X-Amz-Security-Token", creds.session_token.unwrap());
458+
let len: usize = body.as_ref().map_or(0, |b| b.len());
459+
headers.add("Content-Length", len.to_string());
460+
match body {
461+
None => EMPTY_SHA256.into(),
462+
Some(ref v) => {
463+
let clone = v.clone();
464+
async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await
465+
}
465466
}
466-
sign_v4_s3(
467-
method,
468-
&url.path,
469-
region,
470-
headers,
471-
query_params,
472-
&creds.access_key,
473-
&creds.secret_key,
474-
&sha256,
475-
date,
476-
);
477467
}
468+
_ => EMPTY_SHA256.into(),
469+
};
470+
headers.add("x-amz-content-sha256", sha256.clone());
471+
472+
let date = utc_now();
473+
headers.add("x-amz-date", to_amz_date(date));
474+
475+
self.run_before_signing_hooks(
476+
method,
477+
&mut url,
478+
region,
479+
headers,
480+
query_params,
481+
bucket_name,
482+
object_name,
483+
body.clone(),
484+
&mut extensions,
485+
)
486+
.await?;
487+
488+
if let Some(p) = &self.shared.provider {
489+
let creds = p.fetch();
490+
if creds.session_token.is_some() {
491+
headers.add("X-Amz-Security-Token", creds.session_token.unwrap());
492+
}
493+
sign_v4_s3(
494+
method,
495+
&url.path,
496+
region,
497+
headers,
498+
query_params,
499+
&creds.access_key,
500+
&creds.secret_key,
501+
&sha256,
502+
date,
503+
);
478504
}
505+
479506
let mut req = self.http_client.request(method.clone(), url.to_string());
480507

481508
for (key, values) in headers.iter_all() {
@@ -504,7 +531,7 @@ impl Client {
504531

505532
if (*method == Method::PUT) || (*method == Method::POST) {
506533
//TODO: why-oh-why first collect into a vector and then iterate to a stream?
507-
let bytes_vec: Vec<Bytes> = match body {
534+
let bytes_vec: Vec<Bytes> = match body.clone() {
508535
Some(v) => v.iter().collect(),
509536
None => Vec::new(),
510537
};
@@ -516,8 +543,22 @@ impl Client {
516543
req = req.body(Body::wrap_stream(stream));
517544
}
518545

519-
let resp: reqwest::Response = req.send().await?;
546+
let resp = req.send().await;
547+
548+
self.run_after_execute_hooks(
549+
method,
550+
&url,
551+
region,
552+
headers,
553+
query_params,
554+
bucket_name,
555+
object_name,
556+
&resp,
557+
&mut extensions,
558+
)
559+
.await;
520560

561+
let resp = resp?;
521562
if resp.status().is_success() {
522563
return Ok(resp);
523564
}
@@ -596,12 +637,71 @@ impl Client {
596637
)
597638
.await
598639
}
640+
641+
async fn run_after_execute_hooks(
642+
&self,
643+
method: &Method,
644+
url: &Url,
645+
region: &str,
646+
headers: &mut Multimap,
647+
query_params: &Multimap,
648+
bucket_name: Option<&str>,
649+
object_name: Option<&str>,
650+
resp: &Result<Response, reqwest::Error>,
651+
extensions: &mut http::Extensions,
652+
) {
653+
for hook in self.shared.client_hooks.iter() {
654+
hook.after_execute(
655+
method,
656+
url,
657+
region,
658+
headers,
659+
query_params,
660+
bucket_name,
661+
object_name,
662+
resp,
663+
extensions,
664+
)
665+
.await;
666+
}
667+
}
668+
669+
async fn run_before_signing_hooks(
670+
&self,
671+
method: &Method,
672+
url: &mut Url,
673+
region: &str,
674+
headers: &mut Multimap,
675+
query_params: &Multimap,
676+
bucket_name: Option<&str>,
677+
object_name: Option<&str>,
678+
body: Option<Arc<SegmentedBytes>>,
679+
extensions: &mut http::Extensions,
680+
) -> Result<(), Error> {
681+
for hook in self.shared.client_hooks.iter() {
682+
hook.before_signing_mut(
683+
method,
684+
url,
685+
region,
686+
headers,
687+
query_params,
688+
bucket_name,
689+
object_name,
690+
body.as_deref(),
691+
extensions,
692+
)
693+
.await
694+
.inspect_err(|e| log::warn!("Hook {} failed {e}", hook.name()))?;
695+
}
696+
Ok(())
697+
}
599698
}
600699

601700
#[derive(Clone, Debug, Default)]
602701
pub(crate) struct SharedClientItems {
603702
pub(crate) base_url: BaseUrl,
604703
pub(crate) provider: Option<Arc<dyn Provider + Send + Sync + 'static>>,
704+
client_hooks: Vec<Arc<dyn ClientHooks + Send + Sync + 'static>>,
605705
region_map: DashMap<String, String>,
606706
express: OnceLock<bool>,
607707
}

src/s3/client/hooks.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
pub use http::Extensions;
2+
3+
use http::Method;
4+
use reqwest::Response;
5+
use std::fmt::Debug;
6+
use crate::s3::error::Error;
7+
use crate::s3::http::Url;
8+
use crate::s3::multimap::Multimap;
9+
use crate::s3::segmented_bytes::SegmentedBytes;
10+
11+
#[async_trait::async_trait]
12+
pub trait ClientHooks: Debug {
13+
fn name(&self) -> &'static str;
14+
15+
async fn before_signing_mut(
16+
&self,
17+
_method: &Method,
18+
_url: &mut Url,
19+
_region: &str,
20+
_headers: &mut Multimap,
21+
_query_params: &Multimap,
22+
_bucket_name: Option<&str>,
23+
_object_name: Option<&str>,
24+
_body: Option<&SegmentedBytes>,
25+
_extensions: &mut Extensions,
26+
) -> Result<(), Error> {
27+
Ok(())
28+
}
29+
30+
async fn after_execute(
31+
&self,
32+
_method: &Method,
33+
_url: &Url,
34+
_region: &str,
35+
_headers: &Multimap,
36+
_query_params: &Multimap,
37+
_bucket_name: Option<&str>,
38+
_object_name: Option<&str>,
39+
_resp: &Result<Response, reqwest::Error>,
40+
_extensions: &mut Extensions,
41+
) {
42+
}
43+
}

src/s3/error.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ pub enum Error {
182182
NoClientProvided,
183183
TagDecodingError(String, String),
184184
ContentLengthUnknown,
185+
Hook {
186+
source: Box<dyn std::error::Error + Send + Sync>,
187+
name: String,
188+
},
185189
}
186190

187191
impl std::error::Error for Error {}
@@ -343,6 +347,9 @@ impl fmt::Display for Error {
343347
write!(f, "tag decoding failed: {error_message} on input '{input}'")
344348
}
345349
Error::ContentLengthUnknown => write!(f, "content length is unknown"),
350+
Error::Hook { source, name } => {
351+
write!(f, "{} interceptor failed: '{}'", name, source)
352+
}
346353
}
347354
}
348355
}

0 commit comments

Comments
 (0)