From 93806a485ba8e1c736b1959e5466dc03ae363147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20P=C3=BCtz?= Date: Fri, 13 Jun 2025 14:53:26 +0200 Subject: [PATCH] add clienthooks --- Cargo.toml | 1 + src/s3/client.rs | 184 +++++++++++++++++++++++++++++++---------- src/s3/client/hooks.rs | 43 ++++++++++ src/s3/error.rs | 7 ++ 4 files changed, 193 insertions(+), 42 deletions(-) create mode 100644 src/s3/client/hooks.rs diff --git a/Cargo.toml b/Cargo.toml index a333ba3b..f28d7925 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ version = "0.12.18" default-features = false features = ["stream"] + [features] default = ["default-tls", "dep:hmac", "dep:sha2"] default-tls = ["reqwest/default-tls"] diff --git a/src/s3/client.rs b/src/s3/client.rs index 79cc1edb..631a32aa 100644 --- a/src/s3/client.rs +++ b/src/s3/client.rs @@ -15,6 +15,7 @@ //! S3 client to perform bucket and object operations +use std::fmt::Debug; use std::fs::File; use std::io::prelude::*; use std::mem; @@ -22,9 +23,10 @@ use std::path::{Path, PathBuf}; use std::sync::{Arc, OnceLock}; use crate::s3::builders::{BucketExists, ComposeSource}; +pub use crate::s3::client::hooks::RequestLifecycleHooks; use crate::s3::creds::Provider; use crate::s3::error::{Error, ErrorCode, ErrorResponse}; -use crate::s3::http::BaseUrl; +use crate::s3::http::{BaseUrl, Url}; use crate::s3::multimap::{Multimap, MultimapExt}; use crate::s3::response::a_response_traits::{HasEtagFromHeaders, HasS3Fields}; use crate::s3::response::*; @@ -35,9 +37,10 @@ use crate::s3::utils::{EMPTY_SHA256, sha256_hash_sb, to_amz_date, utc_now}; use bytes::Bytes; use dashmap::DashMap; use http::HeaderMap; -use hyper::http::Method; +pub use hyper::http::Method; use rand::Rng; use reqwest::Body; +pub use reqwest::{Error as ReqwestError, Response}; mod append_object; mod bucket_exists; @@ -69,6 +72,7 @@ mod get_object_tagging; mod get_presigned_object_url; mod get_presigned_post_form_data; mod get_region; +pub mod hooks; mod list_buckets; mod list_objects; mod listen_bucket_notification; @@ -123,6 +127,7 @@ pub const MAX_MULTIPART_COUNT: u16 = 10_000; pub struct ClientBuilder { base_url: BaseUrl, provider: Option>, + client_hooks: Vec>, ssl_cert_file: Option, ignore_cert_check: Option, app_info: Option<(String, String)>, @@ -138,6 +143,13 @@ impl ClientBuilder { } } + /// Add a client hook to the builder. Hooks will be called after each other in + /// order they were added. + pub fn hook(mut self, hooks: Arc) -> Self { + self.client_hooks.push(hooks); + self + } + /// Set the credential provider. If not, set anonymous access is used. pub fn provider(mut self, provider: Option

) -> Self { self.provider = provider.map(|p| Arc::new(p) as Arc); @@ -209,6 +221,7 @@ impl ClientBuilder { shared: Arc::new(SharedClientItems { base_url: self.base_url, provider: self.provider, + client_hooks: self.client_hooks, ..Default::default() }), }) @@ -427,55 +440,69 @@ impl Client { body: Option>, retry: bool, ) -> Result { - let url = self.shared.base_url.build_url( + let mut url = self.shared.base_url.build_url( method, region, query_params, bucket_name, object_name, )?; + let mut extensions = http::Extensions::default(); + headers.add("Host", url.host_header_value()); - { - headers.add("Host", url.host_header_value()); - let sha256: String = match *method { - Method::PUT | Method::POST => { - if !headers.contains_key("Content-Type") { - headers.add("Content-Type", "application/octet-stream"); - } - let len: usize = body.as_ref().map_or(0, |b| b.len()); - headers.add("Content-Length", len.to_string()); - match body { - None => EMPTY_SHA256.into(), - Some(ref v) => { - let clone = v.clone(); - async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await - } - } + let sha256: String = match *method { + Method::PUT | Method::POST => { + if !headers.contains_key("Content-Type") { + headers.add("Content-Type", "application/octet-stream"); } - _ => EMPTY_SHA256.into(), - }; - headers.add("x-amz-content-sha256", sha256.clone()); - - let date = utc_now(); - headers.add("x-amz-date", to_amz_date(date)); - if let Some(p) = &self.shared.provider { - let creds = p.fetch(); - if creds.session_token.is_some() { - headers.add("X-Amz-Security-Token", creds.session_token.unwrap()); + let len: usize = body.as_ref().map_or(0, |b| b.len()); + headers.add("Content-Length", len.to_string()); + match body { + None => EMPTY_SHA256.into(), + Some(ref v) => { + let clone = v.clone(); + async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await + } } - sign_v4_s3( - method, - &url.path, - region, - headers, - query_params, - &creds.access_key, - &creds.secret_key, - &sha256, - date, - ); } + _ => EMPTY_SHA256.into(), + }; + headers.add("x-amz-content-sha256", sha256.clone()); + + let date = utc_now(); + headers.add("x-amz-date", to_amz_date(date)); + + self.run_before_signing_hooks( + method, + &mut url, + region, + headers, + query_params, + bucket_name, + object_name, + body.clone(), + &mut extensions, + ) + .await?; + + if let Some(p) = &self.shared.provider { + let creds = p.fetch(); + if creds.session_token.is_some() { + headers.add("X-Amz-Security-Token", creds.session_token.unwrap()); + } + sign_v4_s3( + method, + &url.path, + region, + headers, + query_params, + &creds.access_key, + &creds.secret_key, + &sha256, + date, + ); } + let mut req = self.http_client.request(method.clone(), url.to_string()); for (key, values) in headers.iter_all() { @@ -504,7 +531,7 @@ impl Client { if (*method == Method::PUT) || (*method == Method::POST) { //TODO: why-oh-why first collect into a vector and then iterate to a stream? - let bytes_vec: Vec = match body { + let bytes_vec: Vec = match body.clone() { Some(v) => v.iter().collect(), None => Vec::new(), }; @@ -516,8 +543,22 @@ impl Client { req = req.body(Body::wrap_stream(stream)); } - let resp: reqwest::Response = req.send().await?; + let resp = req.send().await; + + self.run_after_execute_hooks( + method, + &url, + region, + headers, + query_params, + bucket_name, + object_name, + &resp, + &mut extensions, + ) + .await; + let resp = resp?; if resp.status().is_success() { return Ok(resp); } @@ -596,12 +637,71 @@ impl Client { ) .await } + + async fn run_after_execute_hooks( + &self, + method: &Method, + url: &Url, + region: &str, + headers: &mut Multimap, + query_params: &Multimap, + bucket_name: Option<&str>, + object_name: Option<&str>, + resp: &Result, + extensions: &mut http::Extensions, + ) { + for hook in self.shared.client_hooks.iter() { + hook.after_execute( + method, + url, + region, + headers, + query_params, + bucket_name, + object_name, + resp, + extensions, + ) + .await; + } + } + + async fn run_before_signing_hooks( + &self, + method: &Method, + url: &mut Url, + region: &str, + headers: &mut Multimap, + query_params: &Multimap, + bucket_name: Option<&str>, + object_name: Option<&str>, + body: Option>, + extensions: &mut http::Extensions, + ) -> Result<(), Error> { + for hook in self.shared.client_hooks.iter() { + hook.before_signing_mut( + method, + url, + region, + headers, + query_params, + bucket_name, + object_name, + body.as_deref(), + extensions, + ) + .await + .inspect_err(|e| log::warn!("Hook {} failed {e}", hook.name()))?; + } + Ok(()) + } } #[derive(Clone, Debug, Default)] pub(crate) struct SharedClientItems { pub(crate) base_url: BaseUrl, pub(crate) provider: Option>, + client_hooks: Vec>, region_map: DashMap, express: OnceLock, } diff --git a/src/s3/client/hooks.rs b/src/s3/client/hooks.rs new file mode 100644 index 00000000..fc46fdf7 --- /dev/null +++ b/src/s3/client/hooks.rs @@ -0,0 +1,43 @@ +pub use http::Extensions; + +use crate::s3::error::Error; +use crate::s3::http::Url; +use crate::s3::multimap::Multimap; +use crate::s3::segmented_bytes::SegmentedBytes; +use http::Method; +use reqwest::Response; +use std::fmt::Debug; + +#[async_trait::async_trait] +pub trait RequestLifecycleHooks: Debug { + fn name(&self) -> &'static str; + + async fn before_signing_mut( + &self, + _method: &Method, + _url: &mut Url, + _region: &str, + _headers: &mut Multimap, + _query_params: &Multimap, + _bucket_name: Option<&str>, + _object_name: Option<&str>, + _body: Option<&SegmentedBytes>, + _extensions: &mut Extensions, + ) -> Result<(), Error> { + Ok(()) + } + + async fn after_execute( + &self, + _method: &Method, + _url: &Url, + _region: &str, + _headers: &Multimap, + _query_params: &Multimap, + _bucket_name: Option<&str>, + _object_name: Option<&str>, + _resp: &Result, + _extensions: &mut Extensions, + ) { + } +} diff --git a/src/s3/error.rs b/src/s3/error.rs index aedfaf6b..b391098f 100644 --- a/src/s3/error.rs +++ b/src/s3/error.rs @@ -182,6 +182,10 @@ pub enum Error { NoClientProvided, TagDecodingError(String, String), ContentLengthUnknown, + Hook { + source: Box, + name: String, + }, } impl std::error::Error for Error {} @@ -343,6 +347,9 @@ impl fmt::Display for Error { write!(f, "tag decoding failed: {error_message} on input '{input}'") } Error::ContentLengthUnknown => write!(f, "content length is unknown"), + Error::Hook { source, name } => { + write!(f, "{} interceptor failed: '{}'", name, source) + } } } }